feat: add sleeptime to new agent loop (#1900)
This commit is contained in:
@@ -57,6 +57,7 @@ class LettaAgent(BaseAgent):
|
||||
self.block_manager = block_manager
|
||||
self.passage_manager = passage_manager
|
||||
self.use_assistant_message = use_assistant_message
|
||||
self.response_messages: List[Message] = []
|
||||
|
||||
@trace_method
|
||||
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
|
||||
@@ -81,6 +82,7 @@ class LettaAgent(BaseAgent):
|
||||
|
||||
tool_call = response.choices[0].message.tool_calls[0]
|
||||
persisted_messages, should_continue = await self._handle_ai_response(tool_call, agent_state, tool_rules_solver)
|
||||
self.response_messages.extend(persisted_messages)
|
||||
new_in_context_messages.extend(persisted_messages)
|
||||
|
||||
if not should_continue:
|
||||
@@ -139,6 +141,7 @@ class LettaAgent(BaseAgent):
|
||||
pre_computed_assistant_message_id=interface.letta_assistant_message_id,
|
||||
pre_computed_tool_message_id=interface.letta_tool_message_id,
|
||||
)
|
||||
self.response_messages.extend(persisted_messages)
|
||||
new_in_context_messages.extend(persisted_messages)
|
||||
|
||||
if not should_continue:
|
||||
@@ -167,7 +170,14 @@ class LettaAgent(BaseAgent):
|
||||
tools = [
|
||||
t
|
||||
for t in agent_state.tools
|
||||
if t.tool_type in {ToolType.CUSTOM, ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE}
|
||||
if t.tool_type
|
||||
in {
|
||||
ToolType.CUSTOM,
|
||||
ToolType.LETTA_CORE,
|
||||
ToolType.LETTA_MEMORY_CORE,
|
||||
ToolType.LETTA_MULTI_AGENT_CORE,
|
||||
ToolType.LETTA_SLEEPTIME_CORE,
|
||||
}
|
||||
or (t.tool_type == ToolType.LETTA_MULTI_AGENT_CORE and t.name == "send_message_to_agents_matching_tags")
|
||||
]
|
||||
|
||||
|
||||
@@ -88,11 +88,14 @@ def load_multi_agent(
|
||||
def stringify_message(message: Message, use_assistant_name: bool = False) -> str | None:
|
||||
assistant_name = message.name or "assistant" if use_assistant_name else "assistant"
|
||||
if message.role == "user":
|
||||
content = json.loads(message.content[0].text)
|
||||
if content["type"] == "user_message":
|
||||
return f"{message.name or 'user'}: {content['message']}"
|
||||
else:
|
||||
return None
|
||||
try:
|
||||
content = json.loads(message.content[0].text)
|
||||
if content["type"] == "user_message":
|
||||
return f"{message.name or 'user'}: {content['message']}"
|
||||
else:
|
||||
return None
|
||||
except:
|
||||
return f"{message.name or 'user'}: {message.content[0].text}"
|
||||
elif message.role == "assistant":
|
||||
messages = []
|
||||
if message.tool_calls:
|
||||
|
||||
@@ -107,6 +107,9 @@ class SleeptimeMultiAgent(Agent):
|
||||
run_id: str,
|
||||
) -> LettaUsageStatistics:
|
||||
try:
|
||||
job_update = JobUpdate(status=JobStatus.running)
|
||||
self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.user)
|
||||
|
||||
participant_agent_state = self.agent_manager.get_agent_by_id(participant_agent_id, actor=self.user)
|
||||
participant_agent = Agent(
|
||||
agent_state=participant_agent_state,
|
||||
|
||||
249
letta/groups/sleeptime_multi_agent_v2.py
Normal file
249
letta/groups/sleeptime_multi_agent_v2.py
Normal file
@@ -0,0 +1,249 @@
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
from letta.groups.helpers import stringify_message
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.group import Group, ManagerType
|
||||
from letta.schemas.job import JobUpdate
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.run import Run
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.group_manager import GroupManager
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
|
||||
|
||||
class SleeptimeMultiAgentV2(BaseAgent):
|
||||
def __init__(
|
||||
self,
|
||||
agent_id: str,
|
||||
message_manager: MessageManager,
|
||||
agent_manager: AgentManager,
|
||||
block_manager: BlockManager,
|
||||
passage_manager: PassageManager,
|
||||
group_manager: GroupManager,
|
||||
job_manager: JobManager,
|
||||
actor: User,
|
||||
group: Optional[Group] = None,
|
||||
):
|
||||
super().__init__(
|
||||
agent_id=agent_id,
|
||||
openai_client=None,
|
||||
message_manager=message_manager,
|
||||
agent_manager=agent_manager,
|
||||
actor=actor,
|
||||
)
|
||||
self.block_manager = block_manager
|
||||
self.passage_manager = passage_manager
|
||||
self.group_manager = group_manager
|
||||
self.job_manager = job_manager
|
||||
# Group settings
|
||||
assert group.manager_type == ManagerType.sleeptime, f"Expected group manager type to be 'sleeptime', got {group.manager_type}"
|
||||
self.group = group
|
||||
|
||||
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
|
||||
run_ids = []
|
||||
|
||||
# Prepare new messages
|
||||
new_messages = []
|
||||
for message in input_messages:
|
||||
if isinstance(message.content, str):
|
||||
message.content = [TextContent(text=message.content)]
|
||||
message.group_id = self.group.id
|
||||
new_messages.append(message)
|
||||
|
||||
# Load foreground agent
|
||||
foreground_agent = LettaAgent(
|
||||
agent_id=self.agent_id,
|
||||
message_manager=self.message_manager,
|
||||
agent_manager=self.agent_manager,
|
||||
block_manager=self.block_manager,
|
||||
passage_manager=self.passage_manager,
|
||||
actor=self.actor,
|
||||
)
|
||||
# Perform foreground agent step
|
||||
response = await foreground_agent.step(input_messages=new_messages, max_steps=max_steps)
|
||||
|
||||
# Get last response messages
|
||||
last_response_messages = foreground_agent.response_messages
|
||||
|
||||
# Update turns counter
|
||||
if self.group.sleeptime_agent_frequency is not None and self.group.sleeptime_agent_frequency > 0:
|
||||
turns_counter = self.group_manager.bump_turns_counter(group_id=self.group.id, actor=self.actor)
|
||||
|
||||
# Perform participant steps
|
||||
if self.group.sleeptime_agent_frequency is None or (
|
||||
turns_counter is not None and turns_counter % self.group.sleeptime_agent_frequency == 0
|
||||
):
|
||||
last_processed_message_id = self.group_manager.get_last_processed_message_id_and_update(
|
||||
group_id=self.group.id, last_processed_message_id=last_response_messages[-1].id, actor=self.actor
|
||||
)
|
||||
for participant_agent_id in self.group.agent_ids:
|
||||
try:
|
||||
run_id = await self._issue_background_task(
|
||||
participant_agent_id,
|
||||
last_response_messages,
|
||||
last_processed_message_id,
|
||||
)
|
||||
run_ids.append(run_id)
|
||||
|
||||
except Exception as e:
|
||||
# Individual task failures
|
||||
print(f"Agent processing failed: {str(e)}")
|
||||
raise e
|
||||
|
||||
response.usage.run_ids = run_ids
|
||||
return response
|
||||
|
||||
async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = 10) -> AsyncGenerator[str, None]:
|
||||
# Prepare new messages
|
||||
new_messages = []
|
||||
for message in input_messages:
|
||||
if isinstance(message.content, str):
|
||||
message.content = [TextContent(text=message.content)]
|
||||
message.group_id = self.group.id
|
||||
new_messages.append(message)
|
||||
|
||||
# Load foreground agent
|
||||
foreground_agent = LettaAgent(
|
||||
agent_id=self.agent_id,
|
||||
message_manager=self.message_manager,
|
||||
agent_manager=self.agent_manager,
|
||||
block_manager=self.block_manager,
|
||||
passage_manager=self.passage_manager,
|
||||
actor=self.actor,
|
||||
)
|
||||
# Perform foreground agent step
|
||||
async for chunk in foreground_agent.step_stream(input_messages=new_messages, max_steps=max_steps):
|
||||
yield chunk
|
||||
|
||||
# Get response messages
|
||||
last_response_messages = foreground_agent.response_messages
|
||||
|
||||
# Update turns counter
|
||||
if self.group.sleeptime_agent_frequency is not None and self.group.sleeptime_agent_frequency > 0:
|
||||
turns_counter = self.group_manager.bump_turns_counter(group_id=self.group.id, actor=self.actor)
|
||||
|
||||
# Perform participant steps
|
||||
if self.group.sleeptime_agent_frequency is None or (
|
||||
turns_counter is not None and turns_counter % self.group.sleeptime_agent_frequency == 0
|
||||
):
|
||||
last_processed_message_id = self.group_manager.get_last_processed_message_id_and_update(
|
||||
group_id=self.group.id, last_processed_message_id=last_response_messages[-1].id, actor=self.actor
|
||||
)
|
||||
for sleeptime_agent_id in self.group.agent_ids:
|
||||
self._issue_background_task(
|
||||
sleeptime_agent_id,
|
||||
last_response_messages,
|
||||
last_processed_message_id,
|
||||
)
|
||||
|
||||
async def _issue_background_task(
|
||||
self,
|
||||
sleeptime_agent_id: str,
|
||||
response_messages: List[Message],
|
||||
last_processed_message_id: str,
|
||||
) -> str:
|
||||
run = Run(
|
||||
user_id=self.actor.id,
|
||||
status=JobStatus.created,
|
||||
metadata={
|
||||
"job_type": "sleeptime_agent_send_message_async", # is this right?
|
||||
"agent_id": sleeptime_agent_id,
|
||||
},
|
||||
)
|
||||
run = self.job_manager.create_job(pydantic_job=run, actor=self.actor)
|
||||
|
||||
asyncio.create_task(
|
||||
self._participant_agent_step(
|
||||
foreground_agent_id=self.agent_id,
|
||||
sleeptime_agent_id=sleeptime_agent_id,
|
||||
response_messages=response_messages,
|
||||
last_processed_message_id=last_processed_message_id,
|
||||
run_id=run.id,
|
||||
)
|
||||
)
|
||||
return run.id
|
||||
|
||||
async def _participant_agent_step(
|
||||
self,
|
||||
foreground_agent_id: str,
|
||||
sleeptime_agent_id: str,
|
||||
response_messages: List[Message],
|
||||
last_processed_message_id: str,
|
||||
run_id: str,
|
||||
) -> str:
|
||||
try:
|
||||
# Update job status
|
||||
job_update = JobUpdate(status=JobStatus.running)
|
||||
self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.actor)
|
||||
|
||||
# Create conversation transcript
|
||||
prior_messages = []
|
||||
if self.group.sleeptime_agent_frequency:
|
||||
try:
|
||||
prior_messages = self.message_manager.list_messages_for_agent(
|
||||
agent_id=foreground_agent_id,
|
||||
actor=self.actor,
|
||||
after=last_processed_message_id,
|
||||
before=response_messages[0].id,
|
||||
)
|
||||
except Exception:
|
||||
pass # continue with just latest messages
|
||||
|
||||
transcript_summary = [stringify_message(message) for message in prior_messages + response_messages]
|
||||
transcript_summary = [summary for summary in transcript_summary if summary is not None]
|
||||
message_text = "\n".join(transcript_summary)
|
||||
|
||||
sleeptime_agent_messages = [
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=[TextContent(text=message_text)],
|
||||
id=Message.generate_id(),
|
||||
agent_id=sleeptime_agent_id,
|
||||
group_id=self.group.id,
|
||||
)
|
||||
]
|
||||
|
||||
# Load sleeptime agent
|
||||
sleeptime_agent = LettaAgent(
|
||||
agent_id=sleeptime_agent_id,
|
||||
message_manager=self.message_manager,
|
||||
agent_manager=self.agent_manager,
|
||||
block_manager=self.block_manager,
|
||||
passage_manager=self.passage_manager,
|
||||
actor=self.actor,
|
||||
)
|
||||
|
||||
# Perform sleeptime agent step
|
||||
result = await sleeptime_agent.step(
|
||||
input_messages=sleeptime_agent_messages,
|
||||
)
|
||||
|
||||
# Update job status
|
||||
job_update = JobUpdate(
|
||||
status=JobStatus.completed,
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
metadata={
|
||||
"result": result.model_dump(mode="json"),
|
||||
"agent_id": sleeptime_agent_id,
|
||||
},
|
||||
)
|
||||
self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.actor)
|
||||
return result
|
||||
except Exception as e:
|
||||
job_update = JobUpdate(
|
||||
status=JobStatus.failed,
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
metadata={"error": str(e)},
|
||||
)
|
||||
self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.actor)
|
||||
raise
|
||||
@@ -192,7 +192,7 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
AgentManager().rebuild_system_prompt(agent_id=agent_state.id, actor=actor, force=True)
|
||||
return None
|
||||
|
||||
def core_memory_append(self, agent_state: "AgentState", actor: User, label: str, content: str) -> Optional[str]:
|
||||
def core_memory_append(self, agent_state: AgentState, actor: User, label: str, content: str) -> Optional[str]:
|
||||
"""
|
||||
Append to the contents of core memory.
|
||||
|
||||
@@ -211,7 +211,7 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
|
||||
def core_memory_replace(
|
||||
self,
|
||||
agent_state: "AgentState",
|
||||
agent_state: AgentState,
|
||||
actor: User,
|
||||
label: str,
|
||||
old_content: str,
|
||||
@@ -237,7 +237,8 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
return None
|
||||
|
||||
def memory_replace(
|
||||
agent_state: "AgentState",
|
||||
self,
|
||||
agent_state: AgentState,
|
||||
actor: User,
|
||||
label: str,
|
||||
old_str: str,
|
||||
@@ -326,7 +327,8 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
return success_msg
|
||||
|
||||
def memory_insert(
|
||||
agent_state: "AgentState",
|
||||
self,
|
||||
agent_state: AgentState,
|
||||
actor: User,
|
||||
label: str,
|
||||
new_str: str,
|
||||
@@ -407,7 +409,7 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
|
||||
return success_msg
|
||||
|
||||
def memory_rethink(agent_state: "AgentState", actor: User, label: str, new_memory: str) -> str:
|
||||
def memory_rethink(self, agent_state: AgentState, actor: User, label: str, new_memory: str) -> str:
|
||||
"""
|
||||
The memory_rethink command allows you to completely rewrite the contents of a
|
||||
memory block. Use this tool to make large sweeping changes (e.g. when you want
|
||||
@@ -458,7 +460,7 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
# return None
|
||||
return success_msg
|
||||
|
||||
def memory_finish_edits(agent_state: "AgentState") -> None:
|
||||
def memory_finish_edits(self, agent_state: AgentState, actor: User) -> None:
|
||||
"""
|
||||
Call the memory_finish_edits command when you are finished making edits
|
||||
(integrating all new information) into the memory blocks. This function
|
||||
|
||||
@@ -5,6 +5,7 @@ from sqlalchemy import delete
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.constants import DEFAULT_HUMAN
|
||||
from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2
|
||||
from letta.orm import Provider, Step
|
||||
from letta.orm.enums import JobType
|
||||
from letta.orm.errors import NoResultFound
|
||||
@@ -152,10 +153,132 @@ async def test_sleeptime_group_chat(server, actor):
|
||||
assert len(agent_runs) == len(run_ids)
|
||||
|
||||
# 6. Verify run status after sleep
|
||||
time.sleep(10)
|
||||
time.sleep(2)
|
||||
for run_id in run_ids:
|
||||
job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor)
|
||||
assert job.status == JobStatus.completed
|
||||
assert job.status == JobStatus.running or job.status == JobStatus.completed
|
||||
|
||||
# 7. Delete agent
|
||||
server.agent_manager.delete_agent(agent_id=main_agent.id, actor=actor)
|
||||
|
||||
with pytest.raises(NoResultFound):
|
||||
server.group_manager.retrieve_group(group_id=group.id, actor=actor)
|
||||
with pytest.raises(NoResultFound):
|
||||
server.agent_manager.get_agent_by_id(agent_id=sleeptime_agent_id, actor=actor)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sleeptime_group_chat_v2(server, actor):
|
||||
# 0. Refresh base tools
|
||||
server.tool_manager.upsert_base_tools(actor=actor)
|
||||
|
||||
# 1. Create sleeptime agent
|
||||
main_agent = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="main_agent",
|
||||
memory_blocks=[
|
||||
CreateBlock(
|
||||
label="persona",
|
||||
value="You are a personal assistant that helps users with requests.",
|
||||
),
|
||||
CreateBlock(
|
||||
label="human",
|
||||
value="My favorite plant is the fiddle leaf\nMy favorite color is lavender",
|
||||
),
|
||||
],
|
||||
# model="openai/gpt-4o-mini",
|
||||
model="anthropic/claude-3-5-sonnet-20240620",
|
||||
embedding="openai/text-embedding-ada-002",
|
||||
enable_sleeptime=True,
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
assert main_agent.enable_sleeptime == True
|
||||
main_agent_tools = [tool.name for tool in main_agent.tools]
|
||||
assert "core_memory_append" not in main_agent_tools
|
||||
assert "core_memory_replace" not in main_agent_tools
|
||||
assert "archival_memory_insert" not in main_agent_tools
|
||||
|
||||
# 2. Override frequency for test
|
||||
group = server.group_manager.modify_group(
|
||||
group_id=main_agent.multi_agent_group.id,
|
||||
group_update=GroupUpdate(
|
||||
manager_config=SleeptimeManagerUpdate(
|
||||
sleeptime_agent_frequency=2,
|
||||
),
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
assert group.manager_type == ManagerType.sleeptime
|
||||
assert group.sleeptime_agent_frequency == 2
|
||||
assert len(group.agent_ids) == 1
|
||||
|
||||
# 3. Verify shared blocks
|
||||
sleeptime_agent_id = group.agent_ids[0]
|
||||
shared_block = server.agent_manager.get_block_with_label(agent_id=main_agent.id, block_label="human", actor=actor)
|
||||
agents = server.block_manager.get_agents_for_block(block_id=shared_block.id, actor=actor)
|
||||
assert len(agents) == 2
|
||||
assert sleeptime_agent_id in [agent.id for agent in agents]
|
||||
assert main_agent.id in [agent.id for agent in agents]
|
||||
|
||||
# 4 Verify sleeptime agent tools
|
||||
sleeptime_agent = server.agent_manager.get_agent_by_id(agent_id=sleeptime_agent_id, actor=actor)
|
||||
sleeptime_agent_tools = [tool.name for tool in sleeptime_agent.tools]
|
||||
assert "memory_rethink" in sleeptime_agent_tools
|
||||
assert "memory_finish_edits" in sleeptime_agent_tools
|
||||
assert "memory_replace" in sleeptime_agent_tools
|
||||
assert "memory_insert" in sleeptime_agent_tools
|
||||
|
||||
assert len([rule for rule in sleeptime_agent.tool_rules if rule.type == ToolRuleType.exit_loop]) > 0
|
||||
|
||||
# 5. Send messages and verify run ids
|
||||
message_text = [
|
||||
"my favorite color is orange",
|
||||
"not particularly. today is a good day",
|
||||
"actually my favorite color is coral",
|
||||
"let's change the subject",
|
||||
"actually my fav plant is the the african spear",
|
||||
"indeed",
|
||||
]
|
||||
run_ids = []
|
||||
for i, text in enumerate(message_text):
|
||||
agent = SleeptimeMultiAgentV2(
|
||||
agent_id=main_agent.id,
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
group_manager=server.group_manager,
|
||||
job_manager=server.job_manager,
|
||||
actor=actor,
|
||||
group=main_agent.multi_agent_group,
|
||||
)
|
||||
|
||||
response = await agent.step(
|
||||
input_messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=text,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
assert len(response.messages) > 0
|
||||
assert len(response.usage.run_ids or []) == (i + 1) % 2
|
||||
run_ids.extend(response.usage.run_ids or [])
|
||||
|
||||
jobs = server.job_manager.list_jobs(actor=actor, job_type=JobType.RUN)
|
||||
runs = [Run.from_job(job) for job in jobs]
|
||||
agent_runs = [run for run in runs if "agent_id" in run.metadata and run.metadata["agent_id"] == sleeptime_agent_id]
|
||||
assert len(agent_runs) == len(run_ids)
|
||||
|
||||
# 6. Verify run status after sleep
|
||||
time.sleep(2)
|
||||
for run_id in run_ids:
|
||||
job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor)
|
||||
assert job.status == JobStatus.running or job.status == JobStatus.completed
|
||||
|
||||
# 7. Delete agent
|
||||
server.agent_manager.delete_agent(agent_id=main_agent.id, actor=actor)
|
||||
|
||||
Reference in New Issue
Block a user