feat: add sleeptime to new agent loop (#2263)
This commit is contained in:
@@ -19,6 +19,8 @@ from letta.services.group_manager import GroupManager
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.step_manager import NoopStepManager, StepManager
|
||||
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
|
||||
|
||||
|
||||
class SleeptimeMultiAgentV2(BaseAgent):
|
||||
@@ -32,6 +34,8 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
group_manager: GroupManager,
|
||||
job_manager: JobManager,
|
||||
actor: User,
|
||||
step_manager: StepManager = NoopStepManager(),
|
||||
telemetry_manager: TelemetryManager = NoopTelemetryManager(),
|
||||
group: Optional[Group] = None,
|
||||
):
|
||||
super().__init__(
|
||||
@@ -45,11 +49,18 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
self.passage_manager = passage_manager
|
||||
self.group_manager = group_manager
|
||||
self.job_manager = job_manager
|
||||
self.step_manager = step_manager
|
||||
self.telemetry_manager = telemetry_manager
|
||||
# Group settings
|
||||
assert group.manager_type == ManagerType.sleeptime, f"Expected group manager type to be 'sleeptime', got {group.manager_type}"
|
||||
self.group = group
|
||||
|
||||
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
|
||||
async def step(
|
||||
self,
|
||||
input_messages: List[MessageCreate],
|
||||
max_steps: int = 10,
|
||||
use_assistant_message: bool = True,
|
||||
) -> LettaResponse:
|
||||
run_ids = []
|
||||
|
||||
# Prepare new messages
|
||||
@@ -68,22 +79,26 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
block_manager=self.block_manager,
|
||||
passage_manager=self.passage_manager,
|
||||
actor=self.actor,
|
||||
step_manager=self.step_manager,
|
||||
telemetry_manager=self.telemetry_manager,
|
||||
)
|
||||
# Perform foreground agent step
|
||||
response = await foreground_agent.step(input_messages=new_messages, max_steps=max_steps)
|
||||
response = await foreground_agent.step(
|
||||
input_messages=new_messages, max_steps=max_steps, use_assistant_message=use_assistant_message
|
||||
)
|
||||
|
||||
# Get last response messages
|
||||
last_response_messages = foreground_agent.response_messages
|
||||
|
||||
# Update turns counter
|
||||
if self.group.sleeptime_agent_frequency is not None and self.group.sleeptime_agent_frequency > 0:
|
||||
turns_counter = self.group_manager.bump_turns_counter(group_id=self.group.id, actor=self.actor)
|
||||
turns_counter = await self.group_manager.bump_turns_counter_async(group_id=self.group.id, actor=self.actor)
|
||||
|
||||
# Perform participant steps
|
||||
if self.group.sleeptime_agent_frequency is None or (
|
||||
turns_counter is not None and turns_counter % self.group.sleeptime_agent_frequency == 0
|
||||
):
|
||||
last_processed_message_id = self.group_manager.get_last_processed_message_id_and_update(
|
||||
last_processed_message_id = await self.group_manager.get_last_processed_message_id_and_update_async(
|
||||
group_id=self.group.id, last_processed_message_id=last_response_messages[-1].id, actor=self.actor
|
||||
)
|
||||
for participant_agent_id in self.group.agent_ids:
|
||||
@@ -92,6 +107,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
participant_agent_id,
|
||||
last_response_messages,
|
||||
last_processed_message_id,
|
||||
use_assistant_message,
|
||||
)
|
||||
run_ids.append(run_id)
|
||||
|
||||
@@ -103,7 +119,13 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
response.usage.run_ids = run_ids
|
||||
return response
|
||||
|
||||
async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = 10) -> AsyncGenerator[str, None]:
|
||||
async def step_stream(
|
||||
self,
|
||||
input_messages: List[MessageCreate],
|
||||
max_steps: int = 10,
|
||||
use_assistant_message: bool = True,
|
||||
request_start_timestamp_ns: Optional[int] = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
# Prepare new messages
|
||||
new_messages = []
|
||||
for message in input_messages:
|
||||
@@ -120,9 +142,16 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
block_manager=self.block_manager,
|
||||
passage_manager=self.passage_manager,
|
||||
actor=self.actor,
|
||||
step_manager=self.step_manager,
|
||||
telemetry_manager=self.telemetry_manager,
|
||||
)
|
||||
# Perform foreground agent step
|
||||
async for chunk in foreground_agent.step_stream(input_messages=new_messages, max_steps=max_steps):
|
||||
async for chunk in foreground_agent.step_stream(
|
||||
input_messages=new_messages,
|
||||
max_steps=max_steps,
|
||||
use_assistant_message=use_assistant_message,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
):
|
||||
yield chunk
|
||||
|
||||
# Get response messages
|
||||
@@ -130,20 +159,21 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
|
||||
# Update turns counter
|
||||
if self.group.sleeptime_agent_frequency is not None and self.group.sleeptime_agent_frequency > 0:
|
||||
turns_counter = self.group_manager.bump_turns_counter(group_id=self.group.id, actor=self.actor)
|
||||
turns_counter = await self.group_manager.bump_turns_counter_async(group_id=self.group.id, actor=self.actor)
|
||||
|
||||
# Perform participant steps
|
||||
if self.group.sleeptime_agent_frequency is None or (
|
||||
turns_counter is not None and turns_counter % self.group.sleeptime_agent_frequency == 0
|
||||
):
|
||||
last_processed_message_id = self.group_manager.get_last_processed_message_id_and_update(
|
||||
last_processed_message_id = await self.group_manager.get_last_processed_message_id_and_update_async(
|
||||
group_id=self.group.id, last_processed_message_id=last_response_messages[-1].id, actor=self.actor
|
||||
)
|
||||
for sleeptime_agent_id in self.group.agent_ids:
|
||||
self._issue_background_task(
|
||||
run_id = await self._issue_background_task(
|
||||
sleeptime_agent_id,
|
||||
last_response_messages,
|
||||
last_processed_message_id,
|
||||
use_assistant_message,
|
||||
)
|
||||
|
||||
async def _issue_background_task(
|
||||
@@ -151,6 +181,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
sleeptime_agent_id: str,
|
||||
response_messages: List[Message],
|
||||
last_processed_message_id: str,
|
||||
use_assistant_message: bool = True,
|
||||
) -> str:
|
||||
run = Run(
|
||||
user_id=self.actor.id,
|
||||
@@ -160,7 +191,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
"agent_id": sleeptime_agent_id,
|
||||
},
|
||||
)
|
||||
run = self.job_manager.create_job(pydantic_job=run, actor=self.actor)
|
||||
run = await self.job_manager.create_job_async(pydantic_job=run, actor=self.actor)
|
||||
|
||||
asyncio.create_task(
|
||||
self._participant_agent_step(
|
||||
@@ -169,6 +200,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
response_messages=response_messages,
|
||||
last_processed_message_id=last_processed_message_id,
|
||||
run_id=run.id,
|
||||
use_assistant_message=True,
|
||||
)
|
||||
)
|
||||
return run.id
|
||||
@@ -180,11 +212,12 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
response_messages: List[Message],
|
||||
last_processed_message_id: str,
|
||||
run_id: str,
|
||||
use_assistant_message: bool = True,
|
||||
) -> str:
|
||||
try:
|
||||
# Update job status
|
||||
job_update = JobUpdate(status=JobStatus.running)
|
||||
self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.actor)
|
||||
await self.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=self.actor)
|
||||
|
||||
# Create conversation transcript
|
||||
prior_messages = []
|
||||
@@ -221,11 +254,14 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
block_manager=self.block_manager,
|
||||
passage_manager=self.passage_manager,
|
||||
actor=self.actor,
|
||||
step_manager=self.step_manager,
|
||||
telemetry_manager=self.telemetry_manager,
|
||||
)
|
||||
|
||||
# Perform sleeptime agent step
|
||||
result = await sleeptime_agent.step(
|
||||
input_messages=sleeptime_agent_messages,
|
||||
use_assistant_message=use_assistant_message,
|
||||
)
|
||||
|
||||
# Update job status
|
||||
@@ -237,7 +273,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
"agent_id": sleeptime_agent_id,
|
||||
},
|
||||
)
|
||||
self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.actor)
|
||||
await self.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=self.actor)
|
||||
return result
|
||||
except Exception as e:
|
||||
job_update = JobUpdate(
|
||||
@@ -245,5 +281,5 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
completed_at=datetime.now(timezone.utc).replace(tzinfo=None),
|
||||
metadata={"error": str(e)},
|
||||
)
|
||||
self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.actor)
|
||||
await self.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=self.actor)
|
||||
raise
|
||||
|
||||
@@ -13,10 +13,11 @@ from starlette.responses import Response, StreamingResponse
|
||||
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2
|
||||
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent
|
||||
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
|
||||
from letta.schemas.block import Block, BlockUpdate
|
||||
from letta.schemas.group import Group
|
||||
from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
|
||||
@@ -637,22 +638,35 @@ async def send_message(
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
# TODO: This is redundant, remove soon
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor)
|
||||
agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent
|
||||
agent_eligible = agent.enable_sleeptime or not agent.multi_agent_group
|
||||
experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false"
|
||||
feature_enabled = settings.use_experimental or experimental_header.lower() == "true"
|
||||
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"]
|
||||
|
||||
if agent_eligible and feature_enabled and model_compatible:
|
||||
experimental_agent = LettaAgent(
|
||||
agent_id=agent_id,
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
actor=actor,
|
||||
step_manager=server.step_manager,
|
||||
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
|
||||
)
|
||||
if agent.enable_sleeptime:
|
||||
experimental_agent = SleeptimeMultiAgentV2(
|
||||
agent_id=agent_id,
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
group_manager=server.group_manager,
|
||||
job_manager=server.job_manager,
|
||||
actor=actor,
|
||||
group=agent.multi_agent_group,
|
||||
)
|
||||
else:
|
||||
experimental_agent = LettaAgent(
|
||||
agent_id=agent_id,
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
actor=actor,
|
||||
step_manager=server.step_manager,
|
||||
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
|
||||
)
|
||||
|
||||
result = await experimental_agent.step(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message)
|
||||
else:
|
||||
@@ -699,23 +713,38 @@ async def send_message_streaming(
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
# TODO: This is redundant, remove soon
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor)
|
||||
agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent
|
||||
agent_eligible = agent.enable_sleeptime or not agent.multi_agent_group
|
||||
experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false"
|
||||
feature_enabled = settings.use_experimental or experimental_header.lower() == "true"
|
||||
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"]
|
||||
model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai"]
|
||||
|
||||
if agent_eligible and feature_enabled and model_compatible:
|
||||
experimental_agent = LettaAgent(
|
||||
agent_id=agent_id,
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
actor=actor,
|
||||
step_manager=server.step_manager,
|
||||
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
|
||||
)
|
||||
if agent.enable_sleeptime:
|
||||
experimental_agent = SleeptimeMultiAgentV2(
|
||||
agent_id=agent_id,
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
group_manager=server.group_manager,
|
||||
job_manager=server.job_manager,
|
||||
actor=actor,
|
||||
step_manager=server.step_manager,
|
||||
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
|
||||
group=agent.multi_agent_group,
|
||||
)
|
||||
else:
|
||||
experimental_agent = LettaAgent(
|
||||
agent_id=agent_id,
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
actor=actor,
|
||||
step_manager=server.step_manager,
|
||||
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
|
||||
)
|
||||
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
|
||||
|
||||
if request.stream_tokens and model_compatible_token_streaming:
|
||||
|
||||
@@ -232,6 +232,17 @@ class GroupManager:
|
||||
return group.turns_counter
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
async def bump_turns_counter_async(self, group_id: str, actor: PydanticUser) -> int:
|
||||
async with db_registry.async_session() as session:
|
||||
# Ensure group is loadable by user
|
||||
group = await GroupModel.read_async(session, identifier=group_id, actor=actor)
|
||||
|
||||
# Update turns counter
|
||||
group.turns_counter = (group.turns_counter + 1) % group.sleeptime_agent_frequency
|
||||
await group.update_async(session, actor=actor)
|
||||
return group.turns_counter
|
||||
|
||||
@enforce_types
|
||||
def get_last_processed_message_id_and_update(self, group_id: str, last_processed_message_id: str, actor: PydanticUser) -> str:
|
||||
with db_registry.session() as session:
|
||||
@@ -246,6 +257,21 @@ class GroupManager:
|
||||
return prev_last_processed_message_id
|
||||
|
||||
@trace_method
|
||||
@enforce_types
|
||||
async def get_last_processed_message_id_and_update_async(
|
||||
self, group_id: str, last_processed_message_id: str, actor: PydanticUser
|
||||
) -> str:
|
||||
async with db_registry.async_session() as session:
|
||||
# Ensure group is loadable by user
|
||||
group = await GroupModel.read_async(session, identifier=group_id, actor=actor)
|
||||
|
||||
# Update last processed message id
|
||||
prev_last_processed_message_id = group.last_processed_message_id
|
||||
group.last_processed_message_id = last_processed_message_id
|
||||
await group.update_async(session, actor=actor)
|
||||
|
||||
return prev_last_processed_message_id
|
||||
|
||||
@enforce_types
|
||||
def size(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user