feat: add sleeptime to new agent loop (#2263)

This commit is contained in:
cthomas
2025-05-22 23:22:51 -07:00
committed by GitHub
parent 5049f0a623
commit db520f9a22
3 changed files with 127 additions and 36 deletions

View File

@@ -19,6 +19,8 @@ from letta.services.group_manager import GroupManager
from letta.services.job_manager import JobManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.step_manager import NoopStepManager, StepManager
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
class SleeptimeMultiAgentV2(BaseAgent):
@@ -32,6 +34,8 @@ class SleeptimeMultiAgentV2(BaseAgent):
group_manager: GroupManager,
job_manager: JobManager,
actor: User,
step_manager: StepManager = NoopStepManager(),
telemetry_manager: TelemetryManager = NoopTelemetryManager(),
group: Optional[Group] = None,
):
super().__init__(
@@ -45,11 +49,18 @@ class SleeptimeMultiAgentV2(BaseAgent):
self.passage_manager = passage_manager
self.group_manager = group_manager
self.job_manager = job_manager
self.step_manager = step_manager
self.telemetry_manager = telemetry_manager
# Group settings
assert group.manager_type == ManagerType.sleeptime, f"Expected group manager type to be 'sleeptime', got {group.manager_type}"
self.group = group
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
async def step(
self,
input_messages: List[MessageCreate],
max_steps: int = 10,
use_assistant_message: bool = True,
) -> LettaResponse:
run_ids = []
# Prepare new messages
@@ -68,22 +79,26 @@ class SleeptimeMultiAgentV2(BaseAgent):
block_manager=self.block_manager,
passage_manager=self.passage_manager,
actor=self.actor,
step_manager=self.step_manager,
telemetry_manager=self.telemetry_manager,
)
# Perform foreground agent step
response = await foreground_agent.step(input_messages=new_messages, max_steps=max_steps)
response = await foreground_agent.step(
input_messages=new_messages, max_steps=max_steps, use_assistant_message=use_assistant_message
)
# Get last response messages
last_response_messages = foreground_agent.response_messages
# Update turns counter
if self.group.sleeptime_agent_frequency is not None and self.group.sleeptime_agent_frequency > 0:
turns_counter = self.group_manager.bump_turns_counter(group_id=self.group.id, actor=self.actor)
turns_counter = await self.group_manager.bump_turns_counter_async(group_id=self.group.id, actor=self.actor)
# Perform participant steps
if self.group.sleeptime_agent_frequency is None or (
turns_counter is not None and turns_counter % self.group.sleeptime_agent_frequency == 0
):
last_processed_message_id = self.group_manager.get_last_processed_message_id_and_update(
last_processed_message_id = await self.group_manager.get_last_processed_message_id_and_update_async(
group_id=self.group.id, last_processed_message_id=last_response_messages[-1].id, actor=self.actor
)
for participant_agent_id in self.group.agent_ids:
@@ -92,6 +107,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
participant_agent_id,
last_response_messages,
last_processed_message_id,
use_assistant_message,
)
run_ids.append(run_id)
@@ -103,7 +119,13 @@ class SleeptimeMultiAgentV2(BaseAgent):
response.usage.run_ids = run_ids
return response
async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = 10) -> AsyncGenerator[str, None]:
async def step_stream(
self,
input_messages: List[MessageCreate],
max_steps: int = 10,
use_assistant_message: bool = True,
request_start_timestamp_ns: Optional[int] = None,
) -> AsyncGenerator[str, None]:
# Prepare new messages
new_messages = []
for message in input_messages:
@@ -120,9 +142,16 @@ class SleeptimeMultiAgentV2(BaseAgent):
block_manager=self.block_manager,
passage_manager=self.passage_manager,
actor=self.actor,
step_manager=self.step_manager,
telemetry_manager=self.telemetry_manager,
)
# Perform foreground agent step
async for chunk in foreground_agent.step_stream(input_messages=new_messages, max_steps=max_steps):
async for chunk in foreground_agent.step_stream(
input_messages=new_messages,
max_steps=max_steps,
use_assistant_message=use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns,
):
yield chunk
# Get response messages
@@ -130,20 +159,21 @@ class SleeptimeMultiAgentV2(BaseAgent):
# Update turns counter
if self.group.sleeptime_agent_frequency is not None and self.group.sleeptime_agent_frequency > 0:
turns_counter = self.group_manager.bump_turns_counter(group_id=self.group.id, actor=self.actor)
turns_counter = await self.group_manager.bump_turns_counter_async(group_id=self.group.id, actor=self.actor)
# Perform participant steps
if self.group.sleeptime_agent_frequency is None or (
turns_counter is not None and turns_counter % self.group.sleeptime_agent_frequency == 0
):
last_processed_message_id = self.group_manager.get_last_processed_message_id_and_update(
last_processed_message_id = await self.group_manager.get_last_processed_message_id_and_update_async(
group_id=self.group.id, last_processed_message_id=last_response_messages[-1].id, actor=self.actor
)
for sleeptime_agent_id in self.group.agent_ids:
self._issue_background_task(
run_id = await self._issue_background_task(
sleeptime_agent_id,
last_response_messages,
last_processed_message_id,
use_assistant_message,
)
async def _issue_background_task(
@@ -151,6 +181,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
sleeptime_agent_id: str,
response_messages: List[Message],
last_processed_message_id: str,
use_assistant_message: bool = True,
) -> str:
run = Run(
user_id=self.actor.id,
@@ -160,7 +191,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
"agent_id": sleeptime_agent_id,
},
)
run = self.job_manager.create_job(pydantic_job=run, actor=self.actor)
run = await self.job_manager.create_job_async(pydantic_job=run, actor=self.actor)
asyncio.create_task(
self._participant_agent_step(
@@ -169,6 +200,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
response_messages=response_messages,
last_processed_message_id=last_processed_message_id,
run_id=run.id,
use_assistant_message=True,
)
)
return run.id
@@ -180,11 +212,12 @@ class SleeptimeMultiAgentV2(BaseAgent):
response_messages: List[Message],
last_processed_message_id: str,
run_id: str,
use_assistant_message: bool = True,
) -> str:
try:
# Update job status
job_update = JobUpdate(status=JobStatus.running)
self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.actor)
await self.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=self.actor)
# Create conversation transcript
prior_messages = []
@@ -221,11 +254,14 @@ class SleeptimeMultiAgentV2(BaseAgent):
block_manager=self.block_manager,
passage_manager=self.passage_manager,
actor=self.actor,
step_manager=self.step_manager,
telemetry_manager=self.telemetry_manager,
)
# Perform sleeptime agent step
result = await sleeptime_agent.step(
input_messages=sleeptime_agent_messages,
use_assistant_message=use_assistant_message,
)
# Update job status
@@ -237,7 +273,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
"agent_id": sleeptime_agent_id,
},
)
self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.actor)
await self.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=self.actor)
return result
except Exception as e:
job_update = JobUpdate(
@@ -245,5 +281,5 @@ class SleeptimeMultiAgentV2(BaseAgent):
completed_at=datetime.now(timezone.utc).replace(tzinfo=None),
metadata={"error": str(e)},
)
self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.actor)
await self.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=self.actor)
raise

View File

@@ -13,10 +13,11 @@ from starlette.responses import Response, StreamingResponse
from letta.agents.letta_agent import LettaAgent
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
from letta.log import get_logger
from letta.orm.errors import NoResultFound
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
from letta.schemas.block import Block, BlockUpdate
from letta.schemas.group import Group
from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
@@ -637,22 +638,35 @@ async def send_message(
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
# TODO: This is redundant, remove soon
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor)
agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent
agent_eligible = agent.enable_sleeptime or not agent.multi_agent_group
experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false"
feature_enabled = settings.use_experimental or experimental_header.lower() == "true"
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"]
if agent_eligible and feature_enabled and model_compatible:
experimental_agent = LettaAgent(
agent_id=agent_id,
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
actor=actor,
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
)
if agent.enable_sleeptime:
experimental_agent = SleeptimeMultiAgentV2(
agent_id=agent_id,
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
group_manager=server.group_manager,
job_manager=server.job_manager,
actor=actor,
group=agent.multi_agent_group,
)
else:
experimental_agent = LettaAgent(
agent_id=agent_id,
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
actor=actor,
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
)
result = await experimental_agent.step(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message)
else:
@@ -699,23 +713,38 @@ async def send_message_streaming(
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
# TODO: This is redundant, remove soon
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor)
agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent
agent_eligible = agent.enable_sleeptime or not agent.multi_agent_group
experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false"
feature_enabled = settings.use_experimental or experimental_header.lower() == "true"
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"]
model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai"]
if agent_eligible and feature_enabled and model_compatible:
experimental_agent = LettaAgent(
agent_id=agent_id,
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
actor=actor,
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
)
if agent.enable_sleeptime:
experimental_agent = SleeptimeMultiAgentV2(
agent_id=agent_id,
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
group_manager=server.group_manager,
job_manager=server.job_manager,
actor=actor,
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
group=agent.multi_agent_group,
)
else:
experimental_agent = LettaAgent(
agent_id=agent_id,
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
actor=actor,
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
)
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
if request.stream_tokens and model_compatible_token_streaming:

View File

@@ -232,6 +232,17 @@ class GroupManager:
return group.turns_counter
@trace_method
@enforce_types
async def bump_turns_counter_async(self, group_id: str, actor: PydanticUser) -> int:
    """Increment the group's turn counter, wrapping modulo ``sleeptime_agent_frequency``.

    Async counterpart of ``bump_turns_counter``: loads the group via
    ``GroupModel.read_async`` (which also enforces that *actor* may access it),
    bumps and wraps the counter, persists the change, and returns the new value.

    Args:
        group_id: Identifier of the group whose counter is bumped.
        actor: User on whose behalf the read/update is performed.

    Returns:
        The updated (post-increment, wrapped) turns counter.
    """
    async with db_registry.async_session() as session:
        # Ensure group is loadable by user
        group = await GroupModel.read_async(session, identifier=group_id, actor=actor)
        # Update turns counter
        # NOTE(review): the modulo assumes sleeptime_agent_frequency is a positive
        # int — a None or 0 value would raise TypeError/ZeroDivisionError. Callers
        # in SleeptimeMultiAgentV2 appear to guard on frequency > 0; confirm no
        # other caller reaches this with an unset frequency.
        group.turns_counter = (group.turns_counter + 1) % group.sleeptime_agent_frequency
        await group.update_async(session, actor=actor)
        return group.turns_counter
@enforce_types
def get_last_processed_message_id_and_update(self, group_id: str, last_processed_message_id: str, actor: PydanticUser) -> str:
with db_registry.session() as session:
@@ -246,6 +257,21 @@ class GroupManager:
return prev_last_processed_message_id
@trace_method
@enforce_types
async def get_last_processed_message_id_and_update_async(
    self, group_id: str, last_processed_message_id: str, actor: PydanticUser
) -> str:
    """Replace the group's last-processed-message id and return the previous one.

    Async counterpart of ``get_last_processed_message_id_and_update``: reads the
    group via ``GroupModel.read_async`` (which also checks that *actor* may
    access it), records the currently stored id, overwrites it with
    *last_processed_message_id*, persists the change, and returns the old id
    (the read-and-swap happens within a single session, so the caller gets the
    value that was in place just before this update).

    Args:
        group_id: Identifier of the group to update.
        last_processed_message_id: New message id to store as last processed.
        actor: User on whose behalf the read/update is performed.

    Returns:
        The previously stored last-processed-message id.
    """
    async with db_registry.async_session() as session:
        # Ensure group is loadable by user
        group = await GroupModel.read_async(session, identifier=group_id, actor=actor)
        # Update last processed message id
        prev_last_processed_message_id = group.last_processed_message_id
        group.last_processed_message_id = last_processed_message_id
        await group.update_async(session, actor=actor)
        return prev_last_processed_message_id
@enforce_types
def size(
self,