feat: groups support for agentfile (#3721)
This commit is contained in:
@@ -246,14 +246,14 @@ def upgrade() -> None:
|
||||
start_time = time.time()
|
||||
|
||||
batch_size = 1000
|
||||
processed_agents = 0
|
||||
|
||||
# process agents one by one to maintain proper relationships
|
||||
offset = 0
|
||||
while offset < total_agents:
|
||||
# Get batch of agents that need archives
|
||||
batch_result = connection.execute(
|
||||
sa.text("""
|
||||
sa.text(
|
||||
"""
|
||||
SELECT DISTINCT a.id, a.name, a.organization_id
|
||||
FROM agent_passages ap
|
||||
JOIN agents a ON ap.agent_id = a.id
|
||||
@@ -264,22 +264,24 @@ def upgrade() -> None:
|
||||
)
|
||||
ORDER BY a.id
|
||||
LIMIT :batch_size
|
||||
""").bindparams(batch_size=batch_size)
|
||||
"""
|
||||
).bindparams(batch_size=batch_size)
|
||||
)
|
||||
|
||||
|
||||
agents_batch = batch_result.fetchall()
|
||||
if not agents_batch:
|
||||
break # No more agents to process
|
||||
|
||||
|
||||
batch_count = len(agents_batch)
|
||||
print(f"Processing batch of {batch_count} agents (offset: {offset})...")
|
||||
|
||||
|
||||
# Create archive and relationship for each agent
|
||||
for agent_id, agent_name, org_id in agents_batch:
|
||||
try:
|
||||
# Create archive
|
||||
archive_result = connection.execute(
|
||||
sa.text("""
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO archives (id, name, description, organization_id, created_at)
|
||||
VALUES (
|
||||
'archive-' || gen_random_uuid(),
|
||||
@@ -289,26 +291,25 @@ def upgrade() -> None:
|
||||
NOW()
|
||||
)
|
||||
RETURNING id
|
||||
""").bindparams(
|
||||
archive_name=f"{agent_name or f'Agent {agent_id}'}'s Archive",
|
||||
org_id=org_id
|
||||
)
|
||||
"""
|
||||
).bindparams(archive_name=f"{agent_name or f'Agent {agent_id}'}'s Archive", org_id=org_id)
|
||||
)
|
||||
archive_id = archive_result.scalar()
|
||||
|
||||
|
||||
# Create agent-archive relationship
|
||||
connection.execute(
|
||||
sa.text("""
|
||||
sa.text(
|
||||
"""
|
||||
INSERT INTO archives_agents (agent_id, archive_id, is_owner, created_at)
|
||||
VALUES (:agent_id, :archive_id, TRUE, NOW())
|
||||
""").bindparams(agent_id=agent_id, archive_id=archive_id)
|
||||
"""
|
||||
).bindparams(agent_id=agent_id, archive_id=archive_id)
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to create archive for agent {agent_id}: {e}")
|
||||
# Continue with other agents
|
||||
|
||||
|
||||
offset += batch_count
|
||||
processed_agents = offset
|
||||
|
||||
print("Archive creation completed. Starting archive_id updates...")
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from letta.schemas.agent import AgentState, CreateAgent
|
||||
from letta.schemas.block import Block, CreateBlock
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.file import FileAgent, FileAgentBase, FileMetadata, FileMetadataBase
|
||||
from letta.schemas.group import GroupCreate
|
||||
from letta.schemas.group import Group, GroupCreate
|
||||
from letta.schemas.mcp import MCPServer
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.source import Source, SourceCreate
|
||||
@@ -99,6 +99,7 @@ class AgentSchema(CreateAgent):
|
||||
)
|
||||
messages: List[MessageSchema] = Field(default_factory=list, description="List of messages in the agent's conversation history")
|
||||
files_agents: List[FileAgentSchema] = Field(default_factory=list, description="List of file-agent relationships for this agent")
|
||||
group_ids: List[str] = Field(default_factory=list, description="List of groups that the agent manages")
|
||||
|
||||
@classmethod
|
||||
async def from_agent_state(
|
||||
@@ -163,6 +164,7 @@ class AgentSchema(CreateAgent):
|
||||
in_context_message_ids=agent_state.message_ids or [],
|
||||
messages=message_schemas, # Messages will be populated separately by the manager
|
||||
files_agents=[FileAgentSchema.from_file_agent(f) for f in files_agents],
|
||||
group_ids=[agent_state.multi_agent_group.id] if agent_state.multi_agent_group else [],
|
||||
**create_agent.model_dump(),
|
||||
)
|
||||
|
||||
@@ -173,6 +175,21 @@ class GroupSchema(GroupCreate):
|
||||
__id_prefix__ = "group"
|
||||
id: str = Field(..., description="Human-readable identifier for this group in the file")
|
||||
|
||||
@classmethod
|
||||
def from_group(cls, group: Group) -> "GroupSchema":
|
||||
"""Convert Group to GroupSchema"""
|
||||
|
||||
create_group = GroupCreate(
|
||||
agent_ids=group.agent_ids,
|
||||
description=group.description,
|
||||
manager_config=group.manager_config,
|
||||
project_id=group.project_id,
|
||||
shared_block_ids=group.shared_block_ids,
|
||||
)
|
||||
|
||||
# Create GroupSchema with the group's ID (will be remapped later)
|
||||
return cls(id=group.id, **create_group.model_dump())
|
||||
|
||||
|
||||
class BlockSchema(CreateBlock):
|
||||
"""Block with human-readable ID for agent file"""
|
||||
|
||||
@@ -15,6 +15,10 @@ class ManagerType(str, Enum):
|
||||
swarm = "swarm"
|
||||
|
||||
|
||||
class ManagerConfig(BaseModel):
|
||||
manager_type: ManagerType = Field(..., description="")
|
||||
|
||||
|
||||
class GroupBase(LettaBase):
|
||||
__id_prefix__ = "group"
|
||||
|
||||
@@ -42,9 +46,30 @@ class Group(GroupBase):
|
||||
description="The desired minimum length of messages in the context window of the convo agent. This is a best effort, and may be off-by-one due to user/assistant interleaving.",
|
||||
)
|
||||
|
||||
|
||||
class ManagerConfig(BaseModel):
|
||||
manager_type: ManagerType = Field(..., description="")
|
||||
@property
|
||||
def manager_config(self) -> ManagerConfig:
|
||||
match self.manager_type:
|
||||
case ManagerType.round_robin:
|
||||
return RoundRobinManager(max_turns=self.max_turns)
|
||||
case ManagerType.supervisor:
|
||||
return SupervisorManager(manager_agent_id=self.manager_agent_id)
|
||||
case ManagerType.dynamic:
|
||||
return DynamicManager(
|
||||
manager_agent_id=self.manager_agent_id,
|
||||
termination_token=self.termination_token,
|
||||
max_turns=self.max_turns,
|
||||
)
|
||||
case ManagerType.sleeptime:
|
||||
return SleeptimeManager(
|
||||
manager_agent_id=self.manager_agent_id,
|
||||
sleeptime_agent_frequency=self.sleeptime_agent_frequency,
|
||||
)
|
||||
case ManagerType.voice_sleeptime:
|
||||
return VoiceSleeptimeManager(
|
||||
manager_agent_id=self.manager_agent_id,
|
||||
max_message_buffer_length=self.max_message_buffer_length,
|
||||
min_message_buffer_length=self.min_message_buffer_length,
|
||||
)
|
||||
|
||||
|
||||
class RoundRobinManager(ManagerConfig):
|
||||
|
||||
@@ -22,6 +22,7 @@ from letta.schemas.agent_file import (
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.enums import FileProcessingStatus
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.group import Group, GroupCreate
|
||||
from letta.schemas.mcp import MCPServer
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.source import Source
|
||||
@@ -230,6 +231,9 @@ class AgentSerializationManager:
|
||||
file_agent.source_id = self._map_db_to_file_id(file_agent.source_id, SourceSchema.__id_prefix__)
|
||||
file_agent.agent_id = agent_file_id
|
||||
|
||||
if agent_schema.group_ids:
|
||||
agent_schema.group_ids = [self._map_db_to_file_id(group_id, GroupSchema.__id_prefix__) for group_id in agent_schema.group_ids]
|
||||
|
||||
return agent_schema
|
||||
|
||||
def _convert_tool_to_schema(self, tool) -> ToolSchema:
|
||||
@@ -308,6 +312,24 @@ class AgentSerializationManager:
|
||||
logger.error(f"Failed to convert MCP server {mcp_server.id}: {e}")
|
||||
raise
|
||||
|
||||
def _convert_group_to_schema(self, group: Group) -> GroupSchema:
|
||||
"""Convert Group to GroupSchema with ID remapping"""
|
||||
try:
|
||||
group_file_id = self._map_db_to_file_id(group.id, GroupSchema.__id_prefix__, allow_new=False)
|
||||
group_schema = GroupSchema.from_group(group)
|
||||
group_schema.id = group_file_id
|
||||
group_schema.agent_ids = [
|
||||
self._map_db_to_file_id(agent_id, AgentSchema.__id_prefix__, allow_new=False) for agent_id in group_schema.agent_ids
|
||||
]
|
||||
if hasattr(group_schema.manager_config, "manager_agent_id"):
|
||||
group_schema.manager_config.manager_agent_id = self._map_db_to_file_id(
|
||||
group_schema.manager_config.manager_agent_id, AgentSchema.__id_prefix__, allow_new=False
|
||||
)
|
||||
return group_schema
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert group {group.id}: {e}")
|
||||
raise
|
||||
|
||||
async def export(self, agent_ids: List[str], actor: User) -> AgentFileSchema:
|
||||
"""
|
||||
Export agents and their related entities to AgentFileSchema format.
|
||||
@@ -332,6 +354,23 @@ class AgentSerializationManager:
|
||||
missing_ids = [agent_id for agent_id in agent_ids if agent_id not in found_ids]
|
||||
raise AgentFileExportError(f"The following agent IDs were not found: {missing_ids}")
|
||||
|
||||
groups = []
|
||||
group_agent_ids = []
|
||||
for agent_state in agent_states:
|
||||
if agent_state.multi_agent_group != None:
|
||||
groups.append(agent_state.multi_agent_group)
|
||||
group_agent_ids.extend(agent_state.multi_agent_group.agent_ids)
|
||||
|
||||
group_agent_ids = list(set(group_agent_ids) - set(agent_ids))
|
||||
if group_agent_ids:
|
||||
group_agent_states = await self.agent_manager.get_agents_by_ids_async(agent_ids=group_agent_ids, actor=actor)
|
||||
if len(group_agent_states) != len(group_agent_ids):
|
||||
found_ids = {agent.id for agent in group_agent_states}
|
||||
missing_ids = [agent_id for agent_id in group_agent_ids if agent_id not in found_ids]
|
||||
raise AgentFileExportError(f"The following agent IDs were not found: {missing_ids}")
|
||||
agent_ids.extend(group_agent_ids)
|
||||
agent_states.extend(group_agent_states)
|
||||
|
||||
# cache for file-agent relationships to avoid duplicate queries
|
||||
files_agents_cache = {} # Maps agent_id to list of file_agent relationships
|
||||
|
||||
@@ -359,13 +398,14 @@ class AgentSerializationManager:
|
||||
source_schemas = [self._convert_source_to_schema(source) for source in source_set]
|
||||
file_schemas = [self._convert_file_to_schema(file_metadata) for file_metadata in file_set]
|
||||
mcp_server_schemas = [self._convert_mcp_server_to_schema(mcp_server) for mcp_server in mcp_server_set]
|
||||
group_schemas = [self._convert_group_to_schema(group) for group in groups]
|
||||
|
||||
logger.info(f"Exporting {len(agent_ids)} agents to agent file format")
|
||||
|
||||
# Return AgentFileSchema with converted entities
|
||||
return AgentFileSchema(
|
||||
agents=agent_schemas,
|
||||
groups=[], # TODO: Extract and convert groups
|
||||
groups=group_schemas,
|
||||
blocks=block_schemas,
|
||||
files=file_schemas,
|
||||
sources=source_schemas,
|
||||
@@ -607,6 +647,15 @@ class AgentSerializationManager:
|
||||
)
|
||||
imported_count += len(files_for_agent)
|
||||
|
||||
for group in schema.groups:
|
||||
group_data = group.model_dump(exclude={"id"})
|
||||
group_data["agent_ids"] = [file_to_db_ids[agent_id] for agent_id in group_data["agent_ids"]]
|
||||
if "manager_agent_id" in group_data["manager_config"]:
|
||||
group_data["manager_config"]["manager_agent_id"] = file_to_db_ids[group_data["manager_config"]["manager_agent_id"]]
|
||||
created_group = await self.group_manager.create_group_async(GroupCreate(**group_data), actor)
|
||||
file_to_db_ids[group.id] = created_group.id
|
||||
imported_count += 1
|
||||
|
||||
return ImportResult(
|
||||
success=True,
|
||||
message=f"Import completed successfully. Imported {imported_count} entities.",
|
||||
|
||||
@@ -21,6 +21,7 @@ from letta.schemas.agent_file import (
|
||||
from letta.schemas.block import Block, CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.group import ManagerType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.organization import Organization
|
||||
@@ -615,13 +616,20 @@ def _compare_groups(orig: GroupSchema, imp: GroupSchema, index: int) -> List[str
|
||||
"""Compare two GroupSchema objects for logical equivalence."""
|
||||
errors = []
|
||||
|
||||
if orig.name != imp.name:
|
||||
errors.append(f"Group {index}: name mismatch: '{orig.name}' vs '{imp.name}'")
|
||||
orig_agent_ids = sorted(orig.agent_ids)
|
||||
imp_agent_ids = sorted(imp.agent_ids)
|
||||
if orig_agent_ids != imp_agent_ids:
|
||||
errors.append(f"Group {index}: agent_ids mismatch: '{orig_agent_ids}' vs '{imp_agent_ids}'")
|
||||
|
||||
if orig.description != imp.description:
|
||||
errors.append(f"Group {index}: description mismatch")
|
||||
|
||||
if orig.metadata != imp.metadata:
|
||||
if orig.manager_config != imp.manager_config:
|
||||
errors.append(f"Group {index}: manager config mismatch")
|
||||
|
||||
orig_shared_block_ids = sorted(orig.shared_block_ids)
|
||||
imp_shared_block_ids = sorted(imp.shared_block_ids)
|
||||
if orig_shared_block_ids != imp_shared_block_ids:
|
||||
errors.append(f"Group {index}: metadata mismatch")
|
||||
|
||||
return errors
|
||||
@@ -1001,6 +1009,43 @@ class TestAgentFileExport:
|
||||
assert len(agent_file.metadata.get("revision_id")) == 12
|
||||
assert all(c in "0123456789abcdef" for c in agent_file.metadata.get("revision_id"))
|
||||
|
||||
async def test_export_sleeptime_enabled_agent(self, server, agent_serialization_manager, default_user, weather_tool):
|
||||
"""Test exporting sleeptime enabled agent."""
|
||||
create_agent_request = CreateAgent(
|
||||
name="sleeptime-enabled-test-agent",
|
||||
system="Sleeptime enabled test agent",
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
tool_ids=[weather_tool.id],
|
||||
initial_message_sequence=[
|
||||
MessageCreate(role=MessageRole.user, content="Second agent message"),
|
||||
],
|
||||
enable_sleeptime=True,
|
||||
)
|
||||
|
||||
sleeptime_enabled_agent = await server.create_agent_async(
|
||||
request=create_agent_request,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
agent_file = await agent_serialization_manager.export([sleeptime_enabled_agent.id], default_user)
|
||||
|
||||
assert sleeptime_enabled_agent.multi_agent_group != None
|
||||
assert len(agent_file.agents) == 2
|
||||
assert validate_id_format(agent_file)
|
||||
|
||||
agent_ids = {agent.id for agent in agent_file.agents}
|
||||
assert len(agent_ids) == 2
|
||||
|
||||
assert len(agent_file.groups) == 1
|
||||
sleeptime_group = agent_file.groups[0]
|
||||
assert len(sleeptime_group.agent_ids) == 1
|
||||
assert sleeptime_group.agent_ids[0] in agent_ids
|
||||
assert sleeptime_group.manager_config.manager_type == ManagerType.sleeptime
|
||||
assert sleeptime_group.manager_config.manager_agent_id in agent_ids
|
||||
|
||||
await server.agent_manager.delete_agent_async(agent_id=sleeptime_enabled_agent.id, actor=default_user)
|
||||
|
||||
|
||||
class TestAgentFileImport:
|
||||
"""Tests for agent file import functionality."""
|
||||
@@ -1098,6 +1143,42 @@ class TestAgentFileImport:
|
||||
with pytest.raises(AgentFileImportError):
|
||||
await agent_serialization_manager.import_file(invalid_agent_file, other_user)
|
||||
|
||||
async def test_import_sleeptime_enabled_agent(self, server, agent_serialization_manager, default_user, other_user, weather_tool):
|
||||
"""Test basic agent import functionality."""
|
||||
create_agent_request = CreateAgent(
|
||||
name="sleeptime-enabled-test-agent",
|
||||
system="Sleeptime enabled test agent",
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
tool_ids=[weather_tool.id],
|
||||
initial_message_sequence=[
|
||||
MessageCreate(role=MessageRole.user, content="Second agent message"),
|
||||
],
|
||||
enable_sleeptime=True,
|
||||
)
|
||||
|
||||
sleeptime_enabled_agent = await server.create_agent_async(
|
||||
request=create_agent_request,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
sleeptime_enabled_agent.multi_agent_group.id
|
||||
sleeptime_enabled_agent.multi_agent_group.agent_ids[0]
|
||||
|
||||
agent_file = await agent_serialization_manager.export([sleeptime_enabled_agent.id], default_user)
|
||||
|
||||
result = await agent_serialization_manager.import_file(agent_file, other_user)
|
||||
assert result.success
|
||||
assert result.imported_count > 0
|
||||
assert len(result.id_mappings) > 0
|
||||
|
||||
exported_agent_ids = [file_id for file_id in list(result.id_mappings.values()) if file_id.startswith("agent-")]
|
||||
assert len(exported_agent_ids) == 2
|
||||
exported_group_ids = [file_id for file_id in list(result.id_mappings.keys()) if file_id.startswith("group-")]
|
||||
assert len(exported_group_ids) == 1
|
||||
|
||||
await server.agent_manager.delete_agent_async(agent_id=sleeptime_enabled_agent.id, actor=default_user)
|
||||
|
||||
|
||||
class TestAgentFileImportWithProcessing:
|
||||
"""Tests for agent file import with file processing (chunking/embedding)."""
|
||||
|
||||
Reference in New Issue
Block a user