diff --git a/alembic/versions/74f2ede29317_add_background_group_support.py b/alembic/versions/74f2ede29317_add_background_group_support.py new file mode 100644 index 00000000..a7657ec0 --- /dev/null +++ b/alembic/versions/74f2ede29317_add_background_group_support.py @@ -0,0 +1,44 @@ +"""add background group support + +Revision ID: 74f2ede29317 +Revises: bff040379479 +Create Date: 2025-04-01 07:45:31.735977 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "74f2ede29317" +down_revision: Union[str, None] = "bff040379479" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("groups", sa.Column("background_agents_interval", sa.Integer(), nullable=True)) + op.add_column("groups", sa.Column("turns_counter", sa.Integer(), nullable=True)) + op.add_column("groups", sa.Column("last_processed_message_id", sa.String(), nullable=True)) + op.create_table( + "groups_blocks", + sa.Column("group_id", sa.String(), nullable=False), + sa.Column("block_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint(["block_id"], ["block.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["group_id"], ["groups.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("group_id", "block_id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("groups_blocks") + op.drop_column("groups", "last_processed_message_id") + op.drop_column("groups", "turns_counter") + op.drop_column("groups", "background_agents_interval") + # ### end Alembic commands ### diff --git a/letta/groups/background_multi_agent.py b/letta/groups/background_multi_agent.py new file mode 100644 index 00000000..49a7567f --- /dev/null +++ b/letta/groups/background_multi_agent.py @@ -0,0 +1,254 @@ +import asyncio +import threading +from datetime import datetime +from typing import List, Optional + +from letta.agent import Agent, AgentState +from letta.groups.helpers import stringify_message +from letta.interface import AgentInterface +from letta.orm import User +from letta.schemas.enums import JobStatus +from letta.schemas.job import JobUpdate +from letta.schemas.letta_message_content import TextContent +from letta.schemas.message import Message, MessageCreate +from letta.schemas.run import Run +from letta.schemas.usage import LettaUsageStatistics +from letta.services.group_manager import GroupManager +from letta.services.job_manager import JobManager +from letta.services.message_manager import MessageManager + + +class BackgroundMultiAgent(Agent): + + def __init__( + self, + interface: AgentInterface, + agent_state: AgentState, + user: User, + # custom + group_id: str = "", + agent_ids: List[str] = [], + description: str = "", + background_agents_interval: Optional[int] = None, + ): + super().__init__(interface, agent_state, user) + self.group_id = group_id + self.agent_ids = agent_ids + self.description = description + self.background_agents_interval = background_agents_interval + self.group_manager = GroupManager() + self.message_manager = MessageManager() + self.job_manager = JobManager() + + def _run_async_in_new_thread(self, coro): + """Run an async coroutine in a new thread with its own event loop""" + result = None + + def run_async(): + nonlocal result + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + result = loop.run_until_complete(coro) + finally: + loop.close() + asyncio.set_event_loop(None) + + thread = threading.Thread(target=run_async) + thread.start() + thread.join() + return result + + async def _issue_background_task( + self, + participant_agent_id: str, + messages: List[Message], + chaining: bool, + max_chaining_steps: Optional[int], + token_streaming: bool, + metadata: Optional[dict], + put_inner_thoughts_first: bool, + last_processed_message_id: str, + ) -> str: + run = Run( + user_id=self.user.id, + status=JobStatus.created, + metadata={ + "job_type": "background_agent_send_message_async", + "agent_id": participant_agent_id, + }, + ) + run = self.job_manager.create_job(pydantic_job=run, actor=self.user) + + asyncio.create_task( + self._perform_background_agent_step( + participant_agent_id=participant_agent_id, + messages=messages, + chaining=chaining, + max_chaining_steps=max_chaining_steps, + token_streaming=token_streaming, + metadata=metadata, + put_inner_thoughts_first=put_inner_thoughts_first, + last_processed_message_id=last_processed_message_id, + run_id=run.id, + ) + ) + + return run.id + + async def _perform_background_agent_step( + self, + participant_agent_id: str, + messages: List[Message], + chaining: bool, + max_chaining_steps: Optional[int], + token_streaming: bool, + metadata: Optional[dict], + put_inner_thoughts_first: bool, + last_processed_message_id: str, + run_id: str, + ) -> LettaUsageStatistics: + try: + participant_agent_state = self.agent_manager.get_agent_by_id(participant_agent_id, actor=self.user) + participant_agent = Agent( + agent_state=participant_agent_state, + interface=self.interface, + user=self.user, + ) + + prior_messages = [] + if self.background_agents_interval: + try: + prior_messages = self.message_manager.list_messages_for_agent( + agent_id=self.agent_state.id, + actor=self.user, + after=last_processed_message_id, + before=messages[0].id, + ) + except Exception as e: + print(f"Error fetching prior messages: {str(e)}") + # continue with just latest messages + + transcript_summary = [stringify_message(message) for message in prior_messages + messages] + transcript_summary = [summary for summary in transcript_summary if summary is not None] + message_text = "\n".join(transcript_summary) + + participant_agent_messages = [ + Message( + id=Message.generate_id(), + agent_id=participant_agent.agent_state.id, + role="user", + content=[TextContent(text=message_text)], + group_id=self.group_id, + ) + ] + result = participant_agent.step( + messages=participant_agent_messages, + chaining=chaining, + max_chaining_steps=max_chaining_steps, + stream=token_streaming, + skip_verify=True, + metadata=metadata, + put_inner_thoughts_first=put_inner_thoughts_first, + ) + job_update = JobUpdate( + status=JobStatus.completed, + completed_at=datetime.utcnow(), + metadata={"result": result.model_dump(mode="json")}, # Store the result in metadata + ) + self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.user) + return result + except Exception as e: + job_update = JobUpdate( + status=JobStatus.failed, + completed_at=datetime.utcnow(), + metadata={"error": str(e)}, + ) + self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.user) + raise + + def step( + self, + messages: List[MessageCreate], + chaining: bool = True, + max_chaining_steps: Optional[int] = None, + put_inner_thoughts_first: bool = True, + **kwargs, + ) -> LettaUsageStatistics: + run_ids = [] + + token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False + metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None + + messages = [ + Message( + id=Message.generate_id(), + agent_id=self.agent_state.id, + role=message.role, + content=[TextContent(text=message.content)] if isinstance(message.content, str) else message.content, + name=message.name, + model=None, + tool_calls=None, + tool_call_id=None, + group_id=self.group_id, + ) + for message in messages + ] + + try: + main_agent = Agent( + agent_state=self.agent_state, + interface=self.interface, + user=self.user, + ) + usage_stats = main_agent.step( + messages=messages, + chaining=chaining, + max_chaining_steps=max_chaining_steps, + stream=token_streaming, + skip_verify=True, + metadata=metadata, + put_inner_thoughts_first=put_inner_thoughts_first, + ) + + turns_counter = None + if self.background_agents_interval is not None and self.background_agents_interval > 0: + turns_counter = self.group_manager.bump_turns_counter(group_id=self.group_id, actor=self.user) + + if self.background_agents_interval is None or ( + turns_counter is not None and turns_counter % self.background_agents_interval == 0 + ): + last_response_messages = [message for sublist in usage_stats.steps_messages for message in sublist] + last_processed_message_id = self.group_manager.get_last_processed_message_id_and_update( + group_id=self.group_id, last_processed_message_id=last_response_messages[-1].id, actor=self.user + ) + for participant_agent_id in self.agent_ids: + try: + run_id = self._run_async_in_new_thread( + self._issue_background_task( + participant_agent_id, + last_response_messages, + chaining, + max_chaining_steps, + token_streaming, + metadata, + put_inner_thoughts_first, + last_processed_message_id, + ) + ) + run_ids.append(run_id) + + except Exception as e: + # Handle individual task failures + print(f"Agent processing failed: {str(e)}") + raise e + + except Exception as e: + raise e + finally: + self.interface.step_yield() + + self.interface.step_complete() + + usage_stats.run_ids = run_ids + return LettaUsageStatistics(**usage_stats.model_dump()) diff --git a/letta/dynamic_multi_agent.py b/letta/groups/dynamic_multi_agent.py similarity index 100% rename from letta/dynamic_multi_agent.py rename to letta/groups/dynamic_multi_agent.py diff --git a/letta/groups/helpers.py b/letta/groups/helpers.py new file mode 100644 index 00000000..68012836 --- /dev/null +++ b/letta/groups/helpers.py @@ -0,0 +1,104 @@ +import json +from typing import Optional, Union + +from letta.agent import Agent +from letta.interface import AgentInterface +from letta.orm.group import Group +from letta.orm.user import User +from letta.schemas.agent import AgentState +from letta.schemas.group import ManagerType +from letta.schemas.message import Message + + +def load_multi_agent( + group: Group, + agent_state: Optional[AgentState], + actor: User, + interface: Union[AgentInterface, None] = None, +) -> Agent: + if len(group.agent_ids) == 0: + raise ValueError("Empty group: group must have at least one agent") + + if not agent_state: + raise ValueError("Empty manager agent state: manager agent state must be provided") + + match group.manager_type: + case ManagerType.round_robin: + from letta.groups.round_robin_multi_agent import RoundRobinMultiAgent + + return RoundRobinMultiAgent( + agent_state=agent_state, + interface=interface, + user=actor, + group_id=group.id, + agent_ids=group.agent_ids, + description=group.description, + max_turns=group.max_turns, + ) + case ManagerType.dynamic: + from letta.groups.dynamic_multi_agent import DynamicMultiAgent + + return DynamicMultiAgent( + agent_state=agent_state, + interface=interface, + user=actor, + group_id=group.id, + agent_ids=group.agent_ids, + description=group.description, + max_turns=group.max_turns, + termination_token=group.termination_token, + ) + case ManagerType.supervisor: + from letta.groups.supervisor_multi_agent import SupervisorMultiAgent + + return SupervisorMultiAgent( + agent_state=agent_state, + interface=interface, + user=actor, + group_id=group.id, + agent_ids=group.agent_ids, + description=group.description, + ) + case ManagerType.background: + from letta.groups.background_multi_agent import BackgroundMultiAgent + + return BackgroundMultiAgent( + agent_state=agent_state, + interface=interface, + user=actor, + group_id=group.id, + agent_ids=group.agent_ids, + description=group.description, + background_agents_interval=group.background_agents_interval, + ) + case _: + raise ValueError(f"Type {group.manager_type} is not supported.") + + +def stringify_message(message: Message) -> str | None: + if message.role == "user": + content = json.loads(message.content[0].text) + if content["type"] == "user_message": + return f"{message.name or 'user'}: {content['message']}" + else: + return None + elif message.role == "assistant": + messages = [] + if message.content: + messages.append(f"{message.name or 'assistant'}: *thinking* {message.content[0].text}") + if message.tool_calls: + if message.tool_calls[0].function.name == "send_message": + messages.append(f"{message.name or 'assistant'}: {json.loads(message.tool_calls[0].function.arguments)['message']}") + else: + messages.append(f"{message.name or 'assistant'}: Calling tool {message.tool_calls[0].function.name}") + return "\n".join(messages) + elif message.role == "tool": + if message.content: + content = json.loads(message.content[0].text) + if content["message"] != "None" and content["message"] != None: + return f"{message.name or 'assistant'}: Tool call returned {content['message']}" + return None + elif message.role == "system": + return None + + return f"{message.name or 'user'}: {message.content[0].text}" diff --git a/letta/round_robin_multi_agent.py b/letta/groups/round_robin_multi_agent.py similarity index 100% rename from letta/round_robin_multi_agent.py rename to letta/groups/round_robin_multi_agent.py diff --git a/letta/supervisor_multi_agent.py b/letta/groups/supervisor_multi_agent.py similarity index 100% rename from letta/supervisor_multi_agent.py rename to letta/groups/supervisor_multi_agent.py diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index 0dad525c..02af8304 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -7,6 +7,7 @@ from letta.orm.blocks_agents import BlocksAgents from letta.orm.file import FileMetadata from letta.orm.group import Group from letta.orm.groups_agents import GroupsAgents +from letta.orm.groups_blocks import GroupsBlocks from letta.orm.identities_agents import IdentitiesAgents from letta.orm.identities_blocks import IdentitiesBlocks from letta.orm.identity import Identity diff --git a/letta/orm/block.py b/letta/orm/block.py index edd56d0c..30b2f1ab 100644 --- a/letta/orm/block.py +++ b/letta/orm/block.py @@ -67,6 +67,13 @@ class Block(OrganizationMixin, SqlalchemyBase): back_populates="blocks", passive_deletes=True, ) + groups: Mapped[List["Group"]] = relationship( + "Group", + secondary="groups_blocks", + lazy="selectin", + back_populates="shared_blocks", + passive_deletes=True, + ) def to_pydantic(self) -> Type: match self.label: diff --git a/letta/orm/group.py b/letta/orm/group.py index 3599386b..6f939eb0 100644 --- a/letta/orm/group.py +++ b/letta/orm/group.py @@ -20,6 +20,9 @@ class Group(SqlalchemyBase, OrganizationMixin): manager_agent_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("agents.id", ondelete="RESTRICT"), nullable=True, doc="") termination_token: Mapped[Optional[str]] = mapped_column(nullable=True, doc="") max_turns: Mapped[Optional[int]] = mapped_column(nullable=True, doc="") + background_agents_interval: Mapped[Optional[int]] = mapped_column(nullable=True, doc="") + turns_counter: Mapped[Optional[int]] = mapped_column(nullable=True, doc="") + last_processed_message_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="") # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="groups") @@ -27,4 +30,7 @@ class Group(SqlalchemyBase, OrganizationMixin): agents: Mapped[List["Agent"]] = relationship( "Agent", secondary="groups_agents", lazy="selectin", passive_deletes=True, back_populates="groups" ) + shared_blocks: Mapped[List["Block"]] = relationship( + "Block", secondary="groups_blocks", lazy="selectin", passive_deletes=True, back_populates="groups" + ) manager_agent: Mapped["Agent"] = relationship("Agent", lazy="joined", back_populates="multi_agent_group") diff --git a/letta/orm/groups_blocks.py b/letta/orm/groups_blocks.py new file mode 100644 index 00000000..5c5b0205 --- /dev/null +++ b/letta/orm/groups_blocks.py @@ -0,0 +1,13 @@ +from sqlalchemy import ForeignKey, String +from sqlalchemy.orm import Mapped, mapped_column + +from letta.orm.base import Base + + +class GroupsBlocks(Base): + """Groups may have one or many shared blocks associated with them.""" + + __tablename__ = "groups_blocks" + + group_id: Mapped[str] = mapped_column(String, ForeignKey("groups.id", ondelete="CASCADE"), primary_key=True) + block_id: Mapped[str] = mapped_column(String, ForeignKey("block.id", ondelete="CASCADE"), primary_key=True) diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index ee24cb78..f8055263 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -174,6 +174,7 @@ class CreateAgent(BaseModel, validate_assignment=True): # False, description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.", ) + enable_sleeptime: Optional[bool] = Field(False, description="If set to True, memory management will move to a background agent thread.") @field_validator("name") @classmethod @@ -252,6 +253,7 @@ class UpdateAgent(BaseModel): embedding: Optional[str] = Field( None, description="The embedding configuration handle used by the agent, specified in the format provider/model-name." ) + enable_sleeptime: Optional[bool] = Field(False, description="If set to True, memory management will move to a background agent thread.") class Config: extra = "ignore" # Ignores extra fields diff --git a/letta/schemas/group.py b/letta/schemas/group.py index e3d5fadf..4e5b27d2 100644 --- a/letta/schemas/group.py +++ b/letta/schemas/group.py @@ -10,6 +10,7 @@ class ManagerType(str, Enum): round_robin = "round_robin" supervisor = "supervisor" dynamic = "dynamic" + background = "background" swarm = "swarm" @@ -22,10 +23,14 @@ class Group(GroupBase): manager_type: ManagerType = Field(..., description="") agent_ids: List[str] = Field(..., description="") description: str = Field(..., description="") + shared_block_ids: List[str] = Field([], description="") # Pattern fields manager_agent_id: Optional[str] = Field(None, description="") termination_token: Optional[str] = Field(None, description="") max_turns: Optional[int] = Field(None, description="") + background_agents_interval: Optional[int] = Field(None, description="") + turns_counter: Optional[int] = Field(None, description="") + last_processed_message_id: Optional[str] = Field(None, description="") class ManagerConfig(BaseModel): @@ -49,12 +54,18 @@ class DynamicManager(ManagerConfig): max_turns: Optional[int] = Field(None, description="") +class BackgroundManager(ManagerConfig): + manager_type: Literal[ManagerType.background] = Field(ManagerType.background, description="") + manager_agent_id: str = Field(..., description="") + background_agents_interval: Optional[int] = Field(None, description="") + + # class SwarmGroup(ManagerConfig): # manager_type: Literal[ManagerType.swarm] = Field(ManagerType.swarm, description="") ManagerConfigUnion = Annotated[ - Union[RoundRobinManager, SupervisorManager, DynamicManager], + Union[RoundRobinManager, SupervisorManager, DynamicManager, BackgroundManager], Field(discriminator="manager_type"), ] @@ -63,9 +74,11 @@ class GroupCreate(BaseModel): agent_ids: List[str] = Field(..., description="") description: str = Field(..., description="") manager_config: ManagerConfigUnion = Field(RoundRobinManager(), description="") + shared_block_ids: List[str] = Field([], description="") class GroupUpdate(BaseModel): agent_ids: Optional[List[str]] = Field(None, description="") description: Optional[str] = Field(None, description="") manager_config: Optional[ManagerConfigUnion] = Field(None, description="") + shared_block_ids: Optional[List[str]] = Field(None, description="") diff --git a/letta/schemas/usage.py b/letta/schemas/usage.py index c3178bf8..d2f2c688 100644 --- a/letta/schemas/usage.py +++ b/letta/schemas/usage.py @@ -23,3 +23,4 @@ class LettaUsageStatistics(BaseModel): step_count: int = Field(0, description="The number of steps taken by the agent.") # TODO: Optional for now. This field makes everyone's lives easier steps_messages: Optional[List[List[Message]]] = Field(None, description="The messages generated per step") + run_ids: Optional[List[str]] = Field(None, description="The background task run IDs associated with the agent interaction") diff --git a/letta/server/server.py b/letta/server/server.py index 636a5289..6ade4486 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -19,11 +19,11 @@ import letta.system as system from letta.agent import Agent, save_agent from letta.config import LettaConfig from letta.data_sources.connectors import DataConnector, load_data -from letta.dynamic_multi_agent import DynamicMultiAgent from letta.functions.mcp_client.base_client import BaseMCPClient from letta.functions.mcp_client.sse_client import MCP_CONFIG_TOPLEVEL_KEY, SSEMCPClient from letta.functions.mcp_client.stdio_client import StdioMCPClient from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig +from letta.groups.helpers import load_multi_agent from letta.helpers.datetime_helpers import get_utc_time from letta.helpers.json_helpers import json_dumps, json_loads from letta.helpers.message_helper import prepare_input_message_create @@ -34,7 +34,6 @@ from letta.interface import CLIInterface # for printing to terminal from letta.log import get_logger from letta.offline_memory_agent import OfflineMemoryAgent from letta.orm.errors import NoResultFound -from letta.round_robin_multi_agent import RoundRobinMultiAgent from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent from letta.schemas.block import BlockUpdate from letta.schemas.embedding_config import EmbeddingConfig @@ -42,7 +41,6 @@ from letta.schemas.embedding_config import EmbeddingConfig # openai schemas from letta.schemas.enums import JobStatus, MessageStreamStatus from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate -from letta.schemas.group import Group, ManagerType from letta.schemas.job import Job, JobUpdate from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, ToolReturnMessage from letta.schemas.letta_message_content import TextContent @@ -94,7 +92,6 @@ from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSan from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager from letta.settings import model_settings, settings, tool_settings -from letta.supervisor_multi_agent import SupervisorMultiAgent from letta.tracing import trace_method from letta.utils import get_friendly_error_msg @@ -352,7 +349,7 @@ class SyncServer(Server): """Updated method to load agents from persisted storage""" agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) if agent_state.multi_agent_group: - return self.load_multi_agent(agent_state.multi_agent_group, actor, interface, agent_state) + return load_multi_agent(group=agent_state.multi_agent_group, agent_state=agent_state, actor=actor, interface=interface) interface = interface or self.default_interface_factory() if agent_state.agent_type == AgentType.memgpt_agent: @@ -364,49 +361,6 @@ class SyncServer(Server): return agent - def load_multi_agent( - self, group: Group, actor: User, interface: Union[AgentInterface, None] = None, agent_state: Optional[AgentState] = None - ) -> Agent: - if len(group.agent_ids) == 0: - raise ValueError("Empty group: group must have at least one agent") - - match group.manager_type: - case ManagerType.round_robin: - agent_state = agent_state or self.agent_manager.get_agent_by_id(agent_id=group.agent_ids[0], actor=actor) - return RoundRobinMultiAgent( - agent_state=agent_state, - interface=interface, - user=actor, - group_id=group.id, - agent_ids=group.agent_ids, - description=group.description, - max_turns=group.max_turns, - ) - case ManagerType.dynamic: - agent_state = agent_state or self.agent_manager.get_agent_by_id(agent_id=group.manager_agent_id, actor=actor) - return DynamicMultiAgent( - agent_state=agent_state, - interface=interface, - user=actor, - group_id=group.id, - agent_ids=group.agent_ids, - description=group.description, - max_turns=group.max_turns, - termination_token=group.termination_token, - ) - case ManagerType.supervisor: - agent_state = agent_state or self.agent_manager.get_agent_by_id(agent_id=group.manager_agent_id, actor=actor) - return SupervisorMultiAgent( - agent_state=agent_state, - interface=interface, - user=actor, - group_id=group.id, - agent_ids=group.agent_ids, - description=group.description, - ) - case _: - raise ValueError(f"Type {group.manager_type} is not supported.") - def _step( self, actor: User, @@ -1599,7 +1553,9 @@ class SyncServer(Server): raise ValueError("stream_steps must be 'true' if stream_tokens is 'true'") group = self.group_manager.retrieve_group(group_id=group_id, actor=actor) - letta_multi_agent = self.load_multi_agent(group=group, actor=actor) + agent_state_id = group.manager_agent_id or (group.agent_ids[0] if len(group.agent_ids) > 0 else None) + agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_state_id, actor=actor) if agent_state_id else None + letta_multi_agent = load_multi_agent(group=group, agent_state=agent_state, actor=actor) llm_config = letta_multi_agent.agent_state.llm_config supports_token_streaming = ["openai", "anthropic", "deepseek"] diff --git a/letta/services/group_manager.py b/letta/services/group_manager.py index 043f7059..b5ef69ea 100644 --- a/letta/services/group_manager.py +++ b/letta/services/group_manager.py @@ -71,11 +71,20 @@ class GroupManager: case ManagerType.supervisor: new_group.manager_type = ManagerType.supervisor new_group.manager_agent_id = group.manager_config.manager_agent_id + case ManagerType.background: + new_group.manager_type = ManagerType.background + new_group.manager_agent_id = group.manager_config.manager_agent_id + new_group.background_agents_interval = group.manager_config.background_agents_interval + if new_group.background_agents_interval: + new_group.turns_counter = 0 case _: raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}") self._process_agent_relationship(session=session, group=new_group, agent_ids=group.agent_ids, allow_partial=False) + if group.shared_block_ids: + self._process_shared_block_relationship(session=session, group=new_group, block_ids=group.shared_block_ids) + new_group.create(session, actor=actor) return new_group.to_pydantic() @@ -84,6 +93,7 @@ class GroupManager: with self.session_maker() as session: group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) + background_agents_interval = None max_turns = None termination_token = None manager_agent_id = None @@ -99,9 +109,16 @@ class GroupManager: termination_token = group_update.manager_config.termination_token case ManagerType.supervisor: manager_agent_id = group_update.manager_config.manager_agent_id + case ManagerType.background: + manager_agent_id = group_update.manager_config.manager_agent_id + background_agents_interval = group_update.manager_config.background_agents_interval + if background_agents_interval and group.turns_counter is None: + group.turns_counter = 0 case _: raise ValueError(f"Unsupported manager type: {group_update.manager_config.manager_type}") + if background_agents_interval: + group.background_agents_interval = background_agents_interval if max_turns: group.max_turns = max_turns if termination_token: @@ -174,6 +191,30 @@ class GroupManager: session.commit() + @enforce_types + def bump_turns_counter(self, group_id: str, actor: PydanticUser) -> int: + with self.session_maker() as session: + # Ensure group is loadable by user + group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) + + # Update turns counter + group.turns_counter = (group.turns_counter + 1) % group.background_agents_interval + group.update(session, actor=actor) + return group.turns_counter + + @enforce_types + def get_last_processed_message_id_and_update(self, group_id: str, last_processed_message_id: str, actor: PydanticUser) -> str: + with self.session_maker() as session: + # Ensure group is loadable by user + group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) + + # Update last processed message id + prev_last_processed_message_id = group.last_processed_message_id + group.last_processed_message_id = last_processed_message_id + group.update(session, actor=actor) + + return prev_last_processed_message_id + def _process_agent_relationship(self, session: Session, group: GroupModel, agent_ids: List[str], allow_partial=False, replace=True): if not agent_ids: if replace: @@ -203,3 +244,30 @@ class GroupManager: setattr(group, "agent_ids", agent_ids) else: raise ValueError("Extend relationship is not supported for groups.") + + def _process_shared_block_relationship( + self, + session: Session, + group: GroupModel, + block_ids: List[str], + ): + """Process shared block relationships for a group and its agents.""" + from letta.orm import Agent, Block, BlocksAgents + + # Add blocks to group + blocks = session.query(Block).filter(Block.id.in_(block_ids)).all() + group.shared_blocks = blocks + + # Add blocks to all agents + if group.agent_ids: + agents = session.query(Agent).filter(Agent.id.in_(group.agent_ids)).all() + for agent in agents: + for block in blocks: + session.add(BlocksAgents(agent_id=agent.id, block_id=block.id, block_label=block.label)) + + # Add blocks to manager agent if exists + if group.manager_agent_id: + manager_agent = session.query(Agent).filter(Agent.id == group.manager_agent_id).first() + if manager_agent: + for block in blocks: + session.add(BlocksAgents(agent_id=manager_agent.id, block_id=block.id, block_label=block.label)) diff --git a/tests/test_multi_agent.py b/tests/test_multi_agent.py index 2214b143..d4349658 100644 --- a/tests/test_multi_agent.py +++ b/tests/test_multi_agent.py @@ -1,12 +1,28 @@ +import time + import pytest from sqlalchemy import delete from letta.config import LettaConfig +from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS +from letta.functions.functions import parse_source_code +from letta.functions.schema_generator import generate_schema from letta.orm import Provider, Step from letta.schemas.agent import CreateAgent -from letta.schemas.block import CreateBlock -from letta.schemas.group import DynamicManager, GroupCreate, GroupUpdate, ManagerType, RoundRobinManager, SupervisorManager +from letta.schemas.block import Block, CreateBlock +from letta.schemas.enums import JobStatus +from letta.schemas.group import ( + BackgroundManager, + DynamicManager, + GroupCreate, + GroupUpdate, + ManagerType, + RoundRobinManager, + SupervisorManager, +) from letta.schemas.message import MessageCreate +from letta.schemas.tool import Tool +from letta.schemas.tool_rule import TerminalToolRule from letta.server.server import SyncServer @@ -426,3 +442,129 @@ async def test_dynamic_group_chat(server, actor, manager_agent, participant_agen finally: server.group_manager.delete_group(group_id=group.id, actor=actor) + + +@pytest.mark.asyncio +async def test_background_group_chat(server, actor): + # 1. create shared block + shared_memory_block = server.block_manager.create_or_update_block( + Block( + label="human", + value="", + limit=1000, + ), + actor=actor, + ) + + # 2. create main agent + main_agent = server.create_agent( + request=CreateAgent( + name="main_agent", + memory_blocks=[ + CreateBlock( + label="persona", + value="You are a personal assistant that helps users with requests.", + ), + ], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-ada-002", + include_base_tools=False, + tools=BASE_TOOLS, + ), + actor=actor, + ) + + # 3. create background memory agent + def skip_memory_update(): + """ + Perform no memory updates. This function is used when the transcript + does not require any changes to the memory. + """ + + skip_memory_update = Tool( + name=skip_memory_update.__name__, + description="", + source_type="python", + tags=[], + source_code=parse_source_code(skip_memory_update), + json_schema=generate_schema(skip_memory_update, None), + ) + skip_memory_update = server.tool_manager.create_or_update_tool( + pydantic_tool=skip_memory_update, + actor=actor, + ) + + background_memory_agent = server.create_agent( + request=CreateAgent( + name="memory_agent", + memory_blocks=[ + CreateBlock( + label="persona", + value="You manage memory for the main agent. You are a background agent and you are not expected to respond to messages. When you receive a conversation snippet from the main thread, perform memory updates only if there are meaningful changes, and otherwise call the skip_memory_update tool.", + ), + ], + model="openai/gpt-4o-mini", + embedding="openai/text-embedding-ada-002", + include_base_tools=False, + include_base_tool_rules=False, + tools=BASE_MEMORY_TOOLS + [skip_memory_update.name], + tool_rules=[ + TerminalToolRule(tool_name="core_memory_append"), + TerminalToolRule(tool_name="core_memory_replace"), + TerminalToolRule(tool_name="skip_memory_update"), + ], + ), + actor=actor, + ) + + # 4. create group + group = server.group_manager.create_group( + group=GroupCreate( + description="", + agent_ids=[background_memory_agent.id], + manager_config=BackgroundManager( + manager_agent_id=main_agent.id, + background_agents_interval=2, + ), + shared_block_ids=[shared_memory_block.id], + ), + actor=actor, + ) + + agents = server.block_manager.get_agents_for_block(block_id=shared_memory_block.id, actor=actor) + assert len(agents) == 2 + + message_text = [ + "my favorite color is orange", + "not particularly. today is a good day", + "actually my favorite color is coral", + "sorry gotta run", + ] + run_ids = [] + for i, text in enumerate(message_text): + response = await server.send_message_to_agent( + agent_id=main_agent.id, + actor=actor, + messages=[ + MessageCreate( + role="user", + content=text, + ), + ], + stream_steps=False, + stream_tokens=False, + ) + + assert len(response.messages) > 0 + assert len(response.usage.run_ids) == i % 2 + run_ids.extend(response.usage.run_ids) + + time.sleep(5) + + for run_id in run_ids: + job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor) + assert job.status == JobStatus.completed + + server.group_manager.delete_group(group_id=group.id, actor=actor) + server.agent_manager.delete_agent(agent_id=background_memory_agent.id, actor=actor) + server.agent_manager.delete_agent(agent_id=main_agent.id, actor=actor)