feat: groups support for agentfile (#3721)

This commit is contained in:
cthomas
2025-08-04 11:02:00 -07:00
committed by GitHub
parent 2f76ece603
commit 18e11f53af
5 changed files with 197 additions and 24 deletions

View File

@@ -246,14 +246,14 @@ def upgrade() -> None:
start_time = time.time()
batch_size = 1000
processed_agents = 0
# process agents one by one to maintain proper relationships
offset = 0
while offset < total_agents:
# Get batch of agents that need archives
batch_result = connection.execute(
sa.text("""
sa.text(
"""
SELECT DISTINCT a.id, a.name, a.organization_id
FROM agent_passages ap
JOIN agents a ON ap.agent_id = a.id
@@ -264,22 +264,24 @@ def upgrade() -> None:
)
ORDER BY a.id
LIMIT :batch_size
""").bindparams(batch_size=batch_size)
"""
).bindparams(batch_size=batch_size)
)
agents_batch = batch_result.fetchall()
if not agents_batch:
break # No more agents to process
batch_count = len(agents_batch)
print(f"Processing batch of {batch_count} agents (offset: {offset})...")
# Create archive and relationship for each agent
for agent_id, agent_name, org_id in agents_batch:
try:
# Create archive
archive_result = connection.execute(
sa.text("""
sa.text(
"""
INSERT INTO archives (id, name, description, organization_id, created_at)
VALUES (
'archive-' || gen_random_uuid(),
@@ -289,26 +291,25 @@ def upgrade() -> None:
NOW()
)
RETURNING id
""").bindparams(
archive_name=f"{agent_name or f'Agent {agent_id}'}'s Archive",
org_id=org_id
)
"""
).bindparams(archive_name=f"{agent_name or f'Agent {agent_id}'}'s Archive", org_id=org_id)
)
archive_id = archive_result.scalar()
# Create agent-archive relationship
connection.execute(
sa.text("""
sa.text(
"""
INSERT INTO archives_agents (agent_id, archive_id, is_owner, created_at)
VALUES (:agent_id, :archive_id, TRUE, NOW())
""").bindparams(agent_id=agent_id, archive_id=archive_id)
"""
).bindparams(agent_id=agent_id, archive_id=archive_id)
)
except Exception as e:
print(f"Warning: Failed to create archive for agent {agent_id}: {e}")
# Continue with other agents
offset += batch_count
processed_agents = offset
print("Archive creation completed. Starting archive_id updates...")

View File

@@ -7,7 +7,7 @@ from letta.schemas.agent import AgentState, CreateAgent
from letta.schemas.block import Block, CreateBlock
from letta.schemas.enums import MessageRole
from letta.schemas.file import FileAgent, FileAgentBase, FileMetadata, FileMetadataBase
from letta.schemas.group import GroupCreate
from letta.schemas.group import Group, GroupCreate
from letta.schemas.mcp import MCPServer
from letta.schemas.message import Message, MessageCreate
from letta.schemas.source import Source, SourceCreate
@@ -99,6 +99,7 @@ class AgentSchema(CreateAgent):
)
messages: List[MessageSchema] = Field(default_factory=list, description="List of messages in the agent's conversation history")
files_agents: List[FileAgentSchema] = Field(default_factory=list, description="List of file-agent relationships for this agent")
group_ids: List[str] = Field(default_factory=list, description="List of groups that the agent manages")
@classmethod
async def from_agent_state(
@@ -163,6 +164,7 @@ class AgentSchema(CreateAgent):
in_context_message_ids=agent_state.message_ids or [],
messages=message_schemas, # Messages will be populated separately by the manager
files_agents=[FileAgentSchema.from_file_agent(f) for f in files_agents],
group_ids=[agent_state.multi_agent_group.id] if agent_state.multi_agent_group else [],
**create_agent.model_dump(),
)
@@ -173,6 +175,21 @@ class GroupSchema(GroupCreate):
__id_prefix__ = "group"
id: str = Field(..., description="Human-readable identifier for this group in the file")
@classmethod
def from_group(cls, group: Group) -> "GroupSchema":
    """Build a GroupSchema from a Group, carrying over the group's own ID.

    The DB ID is kept verbatim here; the serialization manager remaps it
    to a file-scoped ID later in the export pipeline.
    """
    base_fields = GroupCreate(
        agent_ids=group.agent_ids,
        description=group.description,
        manager_config=group.manager_config,
        project_id=group.project_id,
        shared_block_ids=group.shared_block_ids,
    ).model_dump()
    return cls(id=group.id, **base_fields)
class BlockSchema(CreateBlock):
"""Block with human-readable ID for agent file"""

View File

@@ -15,6 +15,10 @@ class ManagerType(str, Enum):
swarm = "swarm"
class ManagerConfig(BaseModel):
manager_type: ManagerType = Field(..., description="")
class GroupBase(LettaBase):
__id_prefix__ = "group"
@@ -42,9 +46,30 @@ class Group(GroupBase):
description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.",
)
class ManagerConfig(BaseModel):
manager_type: ManagerType = Field(..., description="")
@property
def manager_config(self) -> ManagerConfig:
    """Materialize the typed manager config object for this group's manager type.

    NOTE(review): manager types without a branch here (e.g. swarm) fall
    through and implicitly return None, same as the original match/case.
    """
    mt = self.manager_type
    if mt == ManagerType.round_robin:
        return RoundRobinManager(max_turns=self.max_turns)
    if mt == ManagerType.supervisor:
        return SupervisorManager(manager_agent_id=self.manager_agent_id)
    if mt == ManagerType.dynamic:
        return DynamicManager(
            manager_agent_id=self.manager_agent_id,
            termination_token=self.termination_token,
            max_turns=self.max_turns,
        )
    if mt == ManagerType.sleeptime:
        return SleeptimeManager(
            manager_agent_id=self.manager_agent_id,
            sleeptime_agent_frequency=self.sleeptime_agent_frequency,
        )
    if mt == ManagerType.voice_sleeptime:
        return VoiceSleeptimeManager(
            manager_agent_id=self.manager_agent_id,
            max_message_buffer_length=self.max_message_buffer_length,
            min_message_buffer_length=self.min_message_buffer_length,
        )
class RoundRobinManager(ManagerConfig):

View File

@@ -22,6 +22,7 @@ from letta.schemas.agent_file import (
from letta.schemas.block import Block
from letta.schemas.enums import FileProcessingStatus
from letta.schemas.file import FileMetadata
from letta.schemas.group import Group, GroupCreate
from letta.schemas.mcp import MCPServer
from letta.schemas.message import Message
from letta.schemas.source import Source
@@ -230,6 +231,9 @@ class AgentSerializationManager:
file_agent.source_id = self._map_db_to_file_id(file_agent.source_id, SourceSchema.__id_prefix__)
file_agent.agent_id = agent_file_id
if agent_schema.group_ids:
agent_schema.group_ids = [self._map_db_to_file_id(group_id, GroupSchema.__id_prefix__) for group_id in agent_schema.group_ids]
return agent_schema
def _convert_tool_to_schema(self, tool) -> ToolSchema:
@@ -308,6 +312,24 @@ class AgentSerializationManager:
logger.error(f"Failed to convert MCP server {mcp_server.id}: {e}")
raise
def _convert_group_to_schema(self, group: Group) -> GroupSchema:
    """Convert a Group into its agent-file GroupSchema, remapping DB IDs to file IDs.

    All referenced IDs (the group itself, its member agents, and the optional
    manager agent) must already be registered in the ID map (allow_new=False),
    since agents are converted before groups during export.
    """
    try:
        file_id = self._map_db_to_file_id(group.id, GroupSchema.__id_prefix__, allow_new=False)
        schema = GroupSchema.from_group(group)
        schema.id = file_id

        # Remap every member agent ID to its already-assigned file ID.
        remapped_agent_ids = []
        for db_agent_id in schema.agent_ids:
            remapped_agent_ids.append(self._map_db_to_file_id(db_agent_id, AgentSchema.__id_prefix__, allow_new=False))
        schema.agent_ids = remapped_agent_ids

        # Some manager configs carry a manager agent reference; remap it when present.
        if hasattr(schema.manager_config, "manager_agent_id"):
            schema.manager_config.manager_agent_id = self._map_db_to_file_id(
                schema.manager_config.manager_agent_id, AgentSchema.__id_prefix__, allow_new=False
            )
        return schema
    except Exception as e:
        logger.error(f"Failed to convert group {group.id}: {e}")
        raise
async def export(self, agent_ids: List[str], actor: User) -> AgentFileSchema:
"""
Export agents and their related entities to AgentFileSchema format.
@@ -332,6 +354,23 @@ class AgentSerializationManager:
missing_ids = [agent_id for agent_id in agent_ids if agent_id not in found_ids]
raise AgentFileExportError(f"The following agent IDs were not found: {missing_ids}")
groups = []
group_agent_ids = []
for agent_state in agent_states:
if agent_state.multi_agent_group != None:
groups.append(agent_state.multi_agent_group)
group_agent_ids.extend(agent_state.multi_agent_group.agent_ids)
group_agent_ids = list(set(group_agent_ids) - set(agent_ids))
if group_agent_ids:
group_agent_states = await self.agent_manager.get_agents_by_ids_async(agent_ids=group_agent_ids, actor=actor)
if len(group_agent_states) != len(group_agent_ids):
found_ids = {agent.id for agent in group_agent_states}
missing_ids = [agent_id for agent_id in group_agent_ids if agent_id not in found_ids]
raise AgentFileExportError(f"The following agent IDs were not found: {missing_ids}")
agent_ids.extend(group_agent_ids)
agent_states.extend(group_agent_states)
# cache for file-agent relationships to avoid duplicate queries
files_agents_cache = {} # Maps agent_id to list of file_agent relationships
@@ -359,13 +398,14 @@ class AgentSerializationManager:
source_schemas = [self._convert_source_to_schema(source) for source in source_set]
file_schemas = [self._convert_file_to_schema(file_metadata) for file_metadata in file_set]
mcp_server_schemas = [self._convert_mcp_server_to_schema(mcp_server) for mcp_server in mcp_server_set]
group_schemas = [self._convert_group_to_schema(group) for group in groups]
logger.info(f"Exporting {len(agent_ids)} agents to agent file format")
# Return AgentFileSchema with converted entities
return AgentFileSchema(
agents=agent_schemas,
groups=[], # TODO: Extract and convert groups
groups=group_schemas,
blocks=block_schemas,
files=file_schemas,
sources=source_schemas,
@@ -607,6 +647,15 @@ class AgentSerializationManager:
)
imported_count += len(files_for_agent)
for group in schema.groups:
group_data = group.model_dump(exclude={"id"})
group_data["agent_ids"] = [file_to_db_ids[agent_id] for agent_id in group_data["agent_ids"]]
if "manager_agent_id" in group_data["manager_config"]:
group_data["manager_config"]["manager_agent_id"] = file_to_db_ids[group_data["manager_config"]["manager_agent_id"]]
created_group = await self.group_manager.create_group_async(GroupCreate(**group_data), actor)
file_to_db_ids[group.id] = created_group.id
imported_count += 1
return ImportResult(
success=True,
message=f"Import completed successfully. Imported {imported_count} entities.",

View File

@@ -21,6 +21,7 @@ from letta.schemas.agent_file import (
from letta.schemas.block import Block, CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole
from letta.schemas.group import ManagerType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
from letta.schemas.organization import Organization
@@ -615,13 +616,20 @@ def _compare_groups(orig: GroupSchema, imp: GroupSchema, index: int) -> List[str
"""Compare two GroupSchema objects for logical equivalence."""
errors = []
if orig.name != imp.name:
errors.append(f"Group {index}: name mismatch: '{orig.name}' vs '{imp.name}'")
orig_agent_ids = sorted(orig.agent_ids)
imp_agent_ids = sorted(imp.agent_ids)
if orig_agent_ids != imp_agent_ids:
errors.append(f"Group {index}: agent_ids mismatch: '{orig_agent_ids}' vs '{imp_agent_ids}'")
if orig.description != imp.description:
errors.append(f"Group {index}: description mismatch")
if orig.metadata != imp.metadata:
if orig.manager_config != imp.manager_config:
errors.append(f"Group {index}: manager config mismatch")
orig_shared_block_ids = sorted(orig.shared_block_ids)
imp_shared_block_ids = sorted(imp.shared_block_ids)
if orig_shared_block_ids != imp_shared_block_ids:
errors.append(f"Group {index}: metadata mismatch")
return errors
@@ -1001,6 +1009,43 @@ class TestAgentFileExport:
assert len(agent_file.metadata.get("revision_id")) == 12
assert all(c in "0123456789abcdef" for c in agent_file.metadata.get("revision_id"))
async def test_export_sleeptime_enabled_agent(self, server, agent_serialization_manager, default_user, weather_tool):
    """Test exporting a sleeptime-enabled agent.

    Creating an agent with enable_sleeptime=True implicitly creates a sleeptime
    companion agent plus a managing group, so the exported file should contain
    two agents and one sleeptime group wiring them together.
    """
    create_agent_request = CreateAgent(
        name="sleeptime-enabled-test-agent",
        system="Sleeptime enabled test agent",
        llm_config=LLMConfig.default_config("gpt-4o-mini"),
        embedding_config=EmbeddingConfig.default_config(provider="openai"),
        tool_ids=[weather_tool.id],
        initial_message_sequence=[
            MessageCreate(role=MessageRole.user, content="Second agent message"),
        ],
        enable_sleeptime=True,
    )
    sleeptime_enabled_agent = await server.create_agent_async(
        request=create_agent_request,
        actor=default_user,
    )
    agent_file = await agent_serialization_manager.export([sleeptime_enabled_agent.id], default_user)

    # Fix: use identity comparison for None (`is not None`), not `!= None`.
    assert sleeptime_enabled_agent.multi_agent_group is not None

    # Primary agent plus the auto-created sleeptime agent.
    assert len(agent_file.agents) == 2
    assert validate_id_format(agent_file)
    agent_ids = {agent.id for agent in agent_file.agents}
    assert len(agent_ids) == 2

    # Exactly one group: the sleeptime manager group referencing exported agents.
    assert len(agent_file.groups) == 1
    sleeptime_group = agent_file.groups[0]
    assert len(sleeptime_group.agent_ids) == 1
    assert sleeptime_group.agent_ids[0] in agent_ids
    assert sleeptime_group.manager_config.manager_type == ManagerType.sleeptime
    assert sleeptime_group.manager_config.manager_agent_id in agent_ids

    await server.agent_manager.delete_agent_async(agent_id=sleeptime_enabled_agent.id, actor=default_user)
class TestAgentFileImport:
"""Tests for agent file import functionality."""
@@ -1098,6 +1143,42 @@ class TestAgentFileImport:
with pytest.raises(AgentFileImportError):
await agent_serialization_manager.import_file(invalid_agent_file, other_user)
async def test_import_sleeptime_enabled_agent(self, server, agent_serialization_manager, default_user, other_user, weather_tool):
    """Test round-tripping a sleeptime-enabled agent through export and import.

    The import into a different user's org should recreate both agents (primary
    plus sleeptime companion) and the managing group, which is observable
    through the resulting ID mappings.
    """
    create_agent_request = CreateAgent(
        name="sleeptime-enabled-test-agent",
        system="Sleeptime enabled test agent",
        llm_config=LLMConfig.default_config("gpt-4o-mini"),
        embedding_config=EmbeddingConfig.default_config(provider="openai"),
        tool_ids=[weather_tool.id],
        initial_message_sequence=[
            MessageCreate(role=MessageRole.user, content="Second agent message"),
        ],
        enable_sleeptime=True,
    )
    sleeptime_enabled_agent = await server.create_agent_async(
        request=create_agent_request,
        actor=default_user,
    )
    # (Removed two dead no-op attribute-access statements that had no effect.)
    agent_file = await agent_serialization_manager.export([sleeptime_enabled_agent.id], default_user)
    result = await agent_serialization_manager.import_file(agent_file, other_user)

    assert result.success
    assert result.imported_count > 0
    assert len(result.id_mappings) > 0

    # Two agents (primary + sleeptime) and one group should appear in the mappings.
    exported_agent_ids = [file_id for file_id in result.id_mappings.values() if file_id.startswith("agent-")]
    assert len(exported_agent_ids) == 2
    exported_group_ids = [file_id for file_id in result.id_mappings if file_id.startswith("group-")]
    assert len(exported_group_ids) == 1

    await server.agent_manager.delete_agent_async(agent_id=sleeptime_enabled_agent.id, actor=default_user)
class TestAgentFileImportWithProcessing:
"""Tests for agent file import with file processing (chunking/embedding)."""