From 18e11f53afb02599341a1d6e7f9073f7a64702ff Mon Sep 17 00:00:00 2001 From: cthomas Date: Mon, 4 Aug 2025 11:02:00 -0700 Subject: [PATCH] feat: groups support for agentfile (#3721) --- ...4e860718e0d_add_archival_memory_sharing.py | 33 +++---- letta/schemas/agent_file.py | 19 +++- letta/schemas/group.py | 31 ++++++- letta/services/agent_serialization_manager.py | 51 ++++++++++- tests/test_agent_serialization_v2.py | 87 ++++++++++++++++++- 5 files changed, 197 insertions(+), 24 deletions(-) diff --git a/alembic/versions/74e860718e0d_add_archival_memory_sharing.py b/alembic/versions/74e860718e0d_add_archival_memory_sharing.py index 727cb059..a63e95ee 100644 --- a/alembic/versions/74e860718e0d_add_archival_memory_sharing.py +++ b/alembic/versions/74e860718e0d_add_archival_memory_sharing.py @@ -246,14 +246,14 @@ def upgrade() -> None: start_time = time.time() batch_size = 1000 - processed_agents = 0 # process agents one by one to maintain proper relationships offset = 0 while offset < total_agents: # Get batch of agents that need archives batch_result = connection.execute( - sa.text(""" + sa.text( + """ SELECT DISTINCT a.id, a.name, a.organization_id FROM agent_passages ap JOIN agents a ON ap.agent_id = a.id @@ -264,22 +264,24 @@ def upgrade() -> None: ) ORDER BY a.id LIMIT :batch_size - """).bindparams(batch_size=batch_size) + """ + ).bindparams(batch_size=batch_size) ) - + agents_batch = batch_result.fetchall() if not agents_batch: break # No more agents to process - + batch_count = len(agents_batch) print(f"Processing batch of {batch_count} agents (offset: {offset})...") - + # Create archive and relationship for each agent for agent_id, agent_name, org_id in agents_batch: try: # Create archive archive_result = connection.execute( - sa.text(""" + sa.text( + """ INSERT INTO archives (id, name, description, organization_id, created_at) VALUES ( 'archive-' || gen_random_uuid(), @@ -289,26 +291,25 @@ def upgrade() -> None: NOW() ) RETURNING id - """).bindparams( - archive_name=f"{agent_name or f'Agent {agent_id}'}'s Archive", - org_id=org_id - ) + """ + ).bindparams(archive_name=f"{agent_name or f'Agent {agent_id}'}'s Archive", org_id=org_id) ) archive_id = archive_result.scalar() - + # Create agent-archive relationship connection.execute( - sa.text(""" + sa.text( + """ INSERT INTO archives_agents (agent_id, archive_id, is_owner, created_at) VALUES (:agent_id, :archive_id, TRUE, NOW()) - """).bindparams(agent_id=agent_id, archive_id=archive_id) + """ + ).bindparams(agent_id=agent_id, archive_id=archive_id) ) except Exception as e: print(f"Warning: Failed to create archive for agent {agent_id}: {e}") # Continue with other agents - + offset += batch_count - processed_agents = offset print("Archive creation completed. Starting archive_id updates...") diff --git a/letta/schemas/agent_file.py b/letta/schemas/agent_file.py index c0e4fa47..d00615dd 100644 --- a/letta/schemas/agent_file.py +++ b/letta/schemas/agent_file.py @@ -7,7 +7,7 @@ from letta.schemas.agent import AgentState, CreateAgent from letta.schemas.block import Block, CreateBlock from letta.schemas.enums import MessageRole from letta.schemas.file import FileAgent, FileAgentBase, FileMetadata, FileMetadataBase -from letta.schemas.group import GroupCreate +from letta.schemas.group import Group, GroupCreate from letta.schemas.mcp import MCPServer from letta.schemas.message import Message, MessageCreate from letta.schemas.source import Source, SourceCreate @@ -99,6 +99,7 @@ class AgentSchema(CreateAgent): ) messages: List[MessageSchema] = Field(default_factory=list, description="List of messages in the agent's conversation history") files_agents: List[FileAgentSchema] = Field(default_factory=list, description="List of file-agent relationships for this agent") + group_ids: List[str] = Field(default_factory=list, description="List of groups that the agent manages") @classmethod async def from_agent_state( @@ -163,6 +164,7 @@ class AgentSchema(CreateAgent): in_context_message_ids=agent_state.message_ids or [], messages=message_schemas, # Messages will be populated separately by the manager files_agents=[FileAgentSchema.from_file_agent(f) for f in files_agents], + group_ids=[agent_state.multi_agent_group.id] if agent_state.multi_agent_group else [], **create_agent.model_dump(), ) @@ -173,6 +175,21 @@ class GroupSchema(GroupCreate): __id_prefix__ = "group" id: str = Field(..., description="Human-readable identifier for this group in the file") + @classmethod + def from_group(cls, group: Group) -> "GroupSchema": + """Convert Group to GroupSchema""" + + create_group = GroupCreate( + agent_ids=group.agent_ids, + description=group.description, + manager_config=group.manager_config, + project_id=group.project_id, + shared_block_ids=group.shared_block_ids, + ) + + # Create GroupSchema with the group's ID (will be remapped later) + return cls(id=group.id, **create_group.model_dump()) + class BlockSchema(CreateBlock): """Block with human-readable ID for agent file""" diff --git a/letta/schemas/group.py b/letta/schemas/group.py index fdbfff6d..eb6c6fd8 100644 --- a/letta/schemas/group.py +++ b/letta/schemas/group.py @@ -15,6 +15,10 @@ class ManagerType(str, Enum): swarm = "swarm" +class ManagerConfig(BaseModel): + manager_type: ManagerType = Field(..., description="") + + class GroupBase(LettaBase): __id_prefix__ = "group" @@ -42,9 +46,30 @@ class Group(GroupBase): description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.", ) - -class ManagerConfig(BaseModel): - manager_type: ManagerType = Field(..., description="") + @property + def manager_config(self) -> ManagerConfig: + match self.manager_type: + case ManagerType.round_robin: + return RoundRobinManager(max_turns=self.max_turns) + case ManagerType.supervisor: + return SupervisorManager(manager_agent_id=self.manager_agent_id) + case ManagerType.dynamic: + return DynamicManager( + manager_agent_id=self.manager_agent_id, + termination_token=self.termination_token, + max_turns=self.max_turns, + ) + case ManagerType.sleeptime: + return SleeptimeManager( + manager_agent_id=self.manager_agent_id, + sleeptime_agent_frequency=self.sleeptime_agent_frequency, + ) + case ManagerType.voice_sleeptime: + return VoiceSleeptimeManager( + manager_agent_id=self.manager_agent_id, + max_message_buffer_length=self.max_message_buffer_length, + min_message_buffer_length=self.min_message_buffer_length, + ) class RoundRobinManager(ManagerConfig): diff --git a/letta/services/agent_serialization_manager.py b/letta/services/agent_serialization_manager.py index 3cb1ca02..7c8d4de7 100644 --- a/letta/services/agent_serialization_manager.py +++ b/letta/services/agent_serialization_manager.py @@ -22,6 +22,7 @@ from letta.schemas.agent_file import ( from letta.schemas.block import Block from letta.schemas.enums import FileProcessingStatus from letta.schemas.file import FileMetadata +from letta.schemas.group import Group, GroupCreate from letta.schemas.mcp import MCPServer from letta.schemas.message import Message from letta.schemas.source import Source @@ -230,6 +231,9 @@ class AgentSerializationManager: file_agent.source_id = self._map_db_to_file_id(file_agent.source_id, SourceSchema.__id_prefix__) file_agent.agent_id = agent_file_id + if agent_schema.group_ids: + agent_schema.group_ids = [self._map_db_to_file_id(group_id, GroupSchema.__id_prefix__) for group_id in agent_schema.group_ids] + return agent_schema def _convert_tool_to_schema(self, tool) -> ToolSchema: @@ -308,6 +312,24 @@ class AgentSerializationManager: logger.error(f"Failed to convert MCP server {mcp_server.id}: {e}") raise + def _convert_group_to_schema(self, group: Group) -> GroupSchema: + """Convert Group to GroupSchema with ID remapping""" + try: + group_file_id = self._map_db_to_file_id(group.id, GroupSchema.__id_prefix__, allow_new=False) + group_schema = GroupSchema.from_group(group) + group_schema.id = group_file_id + group_schema.agent_ids = [ + self._map_db_to_file_id(agent_id, AgentSchema.__id_prefix__, allow_new=False) for agent_id in group_schema.agent_ids + ] + if hasattr(group_schema.manager_config, "manager_agent_id"): + group_schema.manager_config.manager_agent_id = self._map_db_to_file_id( + group_schema.manager_config.manager_agent_id, AgentSchema.__id_prefix__, allow_new=False + ) + return group_schema + except Exception as e: + logger.error(f"Failed to convert group {group.id}: {e}") + raise + async def export(self, agent_ids: List[str], actor: User) -> AgentFileSchema: """ Export agents and their related entities to AgentFileSchema format. @@ -332,6 +354,23 @@ class AgentSerializationManager: missing_ids = [agent_id for agent_id in agent_ids if agent_id not in found_ids] raise AgentFileExportError(f"The following agent IDs were not found: {missing_ids}") + groups = [] + group_agent_ids = [] + for agent_state in agent_states: + if agent_state.multi_agent_group != None: + groups.append(agent_state.multi_agent_group) + group_agent_ids.extend(agent_state.multi_agent_group.agent_ids) + + group_agent_ids = list(set(group_agent_ids) - set(agent_ids)) + if group_agent_ids: + group_agent_states = await self.agent_manager.get_agents_by_ids_async(agent_ids=group_agent_ids, actor=actor) + if len(group_agent_states) != len(group_agent_ids): + found_ids = {agent.id for agent in group_agent_states} + missing_ids = [agent_id for agent_id in group_agent_ids if agent_id not in found_ids] + raise AgentFileExportError(f"The following agent IDs were not found: {missing_ids}") + agent_ids.extend(group_agent_ids) + agent_states.extend(group_agent_states) + # cache for file-agent relationships to avoid duplicate queries files_agents_cache = {} # Maps agent_id to list of file_agent relationships @@ -359,13 +398,14 @@ class AgentSerializationManager: source_schemas = [self._convert_source_to_schema(source) for source in source_set] file_schemas = [self._convert_file_to_schema(file_metadata) for file_metadata in file_set] mcp_server_schemas = [self._convert_mcp_server_to_schema(mcp_server) for mcp_server in mcp_server_set] + group_schemas = [self._convert_group_to_schema(group) for group in groups] logger.info(f"Exporting {len(agent_ids)} agents to agent file format") # Return AgentFileSchema with converted entities return AgentFileSchema( agents=agent_schemas, - groups=[], # TODO: Extract and convert groups + groups=group_schemas, blocks=block_schemas, files=file_schemas, sources=source_schemas, @@ -607,6 +647,15 @@ class AgentSerializationManager: ) imported_count += len(files_for_agent) + for group in schema.groups: + group_data = group.model_dump(exclude={"id"}) + group_data["agent_ids"] = [file_to_db_ids[agent_id] for agent_id in group_data["agent_ids"]] + if "manager_agent_id" in group_data["manager_config"]: + group_data["manager_config"]["manager_agent_id"] = file_to_db_ids[group_data["manager_config"]["manager_agent_id"]] + created_group = await self.group_manager.create_group_async(GroupCreate(**group_data), actor) + file_to_db_ids[group.id] = created_group.id + imported_count += 1 + return ImportResult( success=True, message=f"Import completed successfully. Imported {imported_count} entities.", diff --git a/tests/test_agent_serialization_v2.py b/tests/test_agent_serialization_v2.py index 9e06e93f..48dc7229 100644 --- a/tests/test_agent_serialization_v2.py +++ b/tests/test_agent_serialization_v2.py @@ -21,6 +21,7 @@ from letta.schemas.agent_file import ( from letta.schemas.block import Block, CreateBlock from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole +from letta.schemas.group import ManagerType from letta.schemas.llm_config import LLMConfig from letta.schemas.message import MessageCreate from letta.schemas.organization import Organization @@ -615,13 +616,20 @@ def _compare_groups(orig: GroupSchema, imp: GroupSchema, index: int) -> List[str """Compare two GroupSchema objects for logical equivalence.""" errors = [] - if orig.name != imp.name: - errors.append(f"Group {index}: name mismatch: '{orig.name}' vs '{imp.name}'") + orig_agent_ids = sorted(orig.agent_ids) + imp_agent_ids = sorted(imp.agent_ids) + if orig_agent_ids != imp_agent_ids: + errors.append(f"Group {index}: agent_ids mismatch: '{orig_agent_ids}' vs '{imp_agent_ids}'") if orig.description != imp.description: errors.append(f"Group {index}: description mismatch") - if orig.metadata != imp.metadata: + if orig.manager_config != imp.manager_config: + errors.append(f"Group {index}: manager config mismatch") + + orig_shared_block_ids = sorted(orig.shared_block_ids) + imp_shared_block_ids = sorted(imp.shared_block_ids) + if orig_shared_block_ids != imp_shared_block_ids: errors.append(f"Group {index}: metadata mismatch") return errors @@ -1001,6 +1009,43 @@ class TestAgentFileExport: assert len(agent_file.metadata.get("revision_id")) == 12 assert all(c in "0123456789abcdef" for c in agent_file.metadata.get("revision_id")) + async def test_export_sleeptime_enabled_agent(self, server, agent_serialization_manager, default_user, weather_tool): + """Test exporting sleeptime enabled agent.""" + create_agent_request = CreateAgent( + name="sleeptime-enabled-test-agent", + system="Sleeptime enabled test agent", + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + tool_ids=[weather_tool.id], + initial_message_sequence=[ + MessageCreate(role=MessageRole.user, content="Second agent message"), + ], + enable_sleeptime=True, + ) + + sleeptime_enabled_agent = await server.create_agent_async( + request=create_agent_request, + actor=default_user, + ) + + agent_file = await agent_serialization_manager.export([sleeptime_enabled_agent.id], default_user) + + assert sleeptime_enabled_agent.multi_agent_group != None + assert len(agent_file.agents) == 2 + assert validate_id_format(agent_file) + + agent_ids = {agent.id for agent in agent_file.agents} + assert len(agent_ids) == 2 + + assert len(agent_file.groups) == 1 + sleeptime_group = agent_file.groups[0] + assert len(sleeptime_group.agent_ids) == 1 + assert sleeptime_group.agent_ids[0] in agent_ids + assert sleeptime_group.manager_config.manager_type == ManagerType.sleeptime + assert sleeptime_group.manager_config.manager_agent_id in agent_ids + + await server.agent_manager.delete_agent_async(agent_id=sleeptime_enabled_agent.id, actor=default_user) + class TestAgentFileImport: """Tests for agent file import functionality.""" @@ -1098,6 +1143,42 @@ class TestAgentFileImport: with pytest.raises(AgentFileImportError): await agent_serialization_manager.import_file(invalid_agent_file, other_user) + async def test_import_sleeptime_enabled_agent(self, server, agent_serialization_manager, default_user, other_user, weather_tool): + """Test basic agent import functionality.""" + create_agent_request = CreateAgent( + name="sleeptime-enabled-test-agent", + system="Sleeptime enabled test agent", + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + tool_ids=[weather_tool.id], + initial_message_sequence=[ + MessageCreate(role=MessageRole.user, content="Second agent message"), + ], + enable_sleeptime=True, + ) + + sleeptime_enabled_agent = await server.create_agent_async( + request=create_agent_request, + actor=default_user, + ) + + sleeptime_enabled_agent.multi_agent_group.id + sleeptime_enabled_agent.multi_agent_group.agent_ids[0] + + agent_file = await agent_serialization_manager.export([sleeptime_enabled_agent.id], default_user) + + result = await agent_serialization_manager.import_file(agent_file, other_user) + assert result.success + assert result.imported_count > 0 + assert len(result.id_mappings) > 0 + + exported_agent_ids = [file_id for file_id in list(result.id_mappings.values()) if file_id.startswith("agent-")] + assert len(exported_agent_ids) == 2 + exported_group_ids = [file_id for file_id in list(result.id_mappings.keys()) if file_id.startswith("group-")] + assert len(exported_group_ids) == 1 + + await server.agent_manager.delete_agent_async(agent_id=sleeptime_enabled_agent.id, actor=default_user) + class TestAgentFileImportWithProcessing: """Tests for agent file import with file processing (chunking/embedding)."""