diff --git a/letta/groups/sleeptime_multi_agent_v2.py b/letta/groups/sleeptime_multi_agent_v2.py index 9cd2cede..f082ca38 100644 --- a/letta/groups/sleeptime_multi_agent_v2.py +++ b/letta/groups/sleeptime_multi_agent_v2.py @@ -19,6 +19,8 @@ from letta.services.group_manager import GroupManager from letta.services.job_manager import JobManager from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager +from letta.services.step_manager import NoopStepManager, StepManager +from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager class SleeptimeMultiAgentV2(BaseAgent): @@ -32,6 +34,8 @@ class SleeptimeMultiAgentV2(BaseAgent): group_manager: GroupManager, job_manager: JobManager, actor: User, + step_manager: StepManager = NoopStepManager(), + telemetry_manager: TelemetryManager = NoopTelemetryManager(), group: Optional[Group] = None, ): super().__init__( @@ -45,11 +49,18 @@ class SleeptimeMultiAgentV2(BaseAgent): self.passage_manager = passage_manager self.group_manager = group_manager self.job_manager = job_manager + self.step_manager = step_manager + self.telemetry_manager = telemetry_manager # Group settings assert group.manager_type == ManagerType.sleeptime, f"Expected group manager type to be 'sleeptime', got {group.manager_type}" self.group = group - async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse: + async def step( + self, + input_messages: List[MessageCreate], + max_steps: int = 10, + use_assistant_message: bool = True, + ) -> LettaResponse: run_ids = [] # Prepare new messages @@ -68,22 +79,26 @@ class SleeptimeMultiAgentV2(BaseAgent): block_manager=self.block_manager, passage_manager=self.passage_manager, actor=self.actor, + step_manager=self.step_manager, + telemetry_manager=self.telemetry_manager, ) # Perform foreground agent step - response = await foreground_agent.step(input_messages=new_messages, max_steps=max_steps) + 
response = await foreground_agent.step( + input_messages=new_messages, max_steps=max_steps, use_assistant_message=use_assistant_message + ) # Get last response messages last_response_messages = foreground_agent.response_messages # Update turns counter if self.group.sleeptime_agent_frequency is not None and self.group.sleeptime_agent_frequency > 0: - turns_counter = self.group_manager.bump_turns_counter(group_id=self.group.id, actor=self.actor) + turns_counter = await self.group_manager.bump_turns_counter_async(group_id=self.group.id, actor=self.actor) # Perform participant steps if self.group.sleeptime_agent_frequency is None or ( turns_counter is not None and turns_counter % self.group.sleeptime_agent_frequency == 0 ): - last_processed_message_id = self.group_manager.get_last_processed_message_id_and_update( + last_processed_message_id = await self.group_manager.get_last_processed_message_id_and_update_async( group_id=self.group.id, last_processed_message_id=last_response_messages[-1].id, actor=self.actor ) for participant_agent_id in self.group.agent_ids: @@ -92,6 +107,7 @@ class SleeptimeMultiAgentV2(BaseAgent): participant_agent_id, last_response_messages, last_processed_message_id, + use_assistant_message, ) run_ids.append(run_id) @@ -103,7 +119,13 @@ class SleeptimeMultiAgentV2(BaseAgent): response.usage.run_ids = run_ids return response - async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = 10) -> AsyncGenerator[str, None]: + async def step_stream( + self, + input_messages: List[MessageCreate], + max_steps: int = 10, + use_assistant_message: bool = True, + request_start_timestamp_ns: Optional[int] = None, + ) -> AsyncGenerator[str, None]: # Prepare new messages new_messages = [] for message in input_messages: @@ -120,9 +142,16 @@ class SleeptimeMultiAgentV2(BaseAgent): block_manager=self.block_manager, passage_manager=self.passage_manager, actor=self.actor, + step_manager=self.step_manager, + 
telemetry_manager=self.telemetry_manager, ) # Perform foreground agent step - async for chunk in foreground_agent.step_stream(input_messages=new_messages, max_steps=max_steps): + async for chunk in foreground_agent.step_stream( + input_messages=new_messages, + max_steps=max_steps, + use_assistant_message=use_assistant_message, + request_start_timestamp_ns=request_start_timestamp_ns, + ): yield chunk # Get response messages @@ -130,20 +159,21 @@ class SleeptimeMultiAgentV2(BaseAgent): # Update turns counter if self.group.sleeptime_agent_frequency is not None and self.group.sleeptime_agent_frequency > 0: - turns_counter = self.group_manager.bump_turns_counter(group_id=self.group.id, actor=self.actor) + turns_counter = await self.group_manager.bump_turns_counter_async(group_id=self.group.id, actor=self.actor) # Perform participant steps if self.group.sleeptime_agent_frequency is None or ( turns_counter is not None and turns_counter % self.group.sleeptime_agent_frequency == 0 ): - last_processed_message_id = self.group_manager.get_last_processed_message_id_and_update( + last_processed_message_id = await self.group_manager.get_last_processed_message_id_and_update_async( group_id=self.group.id, last_processed_message_id=last_response_messages[-1].id, actor=self.actor ) for sleeptime_agent_id in self.group.agent_ids: - self._issue_background_task( + run_id = await self._issue_background_task( sleeptime_agent_id, last_response_messages, last_processed_message_id, + use_assistant_message, ) async def _issue_background_task( @@ -151,6 +181,7 @@ class SleeptimeMultiAgentV2(BaseAgent): sleeptime_agent_id: str, response_messages: List[Message], last_processed_message_id: str, + use_assistant_message: bool = True, ) -> str: run = Run( user_id=self.actor.id, @@ -160,7 +191,7 @@ class SleeptimeMultiAgentV2(BaseAgent): "agent_id": sleeptime_agent_id, }, ) - run = self.job_manager.create_job(pydantic_job=run, actor=self.actor) + run = await 
self.job_manager.create_job_async(pydantic_job=run, actor=self.actor)
 
         asyncio.create_task(
             self._participant_agent_step(
@@ -169,6 +200,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
                 response_messages=response_messages,
                 last_processed_message_id=last_processed_message_id,
                 run_id=run.id,
+                use_assistant_message=use_assistant_message,
             )
         )
         return run.id
@@ -180,11 +212,12 @@ class SleeptimeMultiAgentV2(BaseAgent):
         response_messages: List[Message],
         last_processed_message_id: str,
         run_id: str,
+        use_assistant_message: bool = True,
     ) -> str:
         try:
             # Update job status
             job_update = JobUpdate(status=JobStatus.running)
-            self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.actor)
+            await self.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=self.actor)
 
             # Create conversation transcript
             prior_messages = []
@@ -221,11 +254,14 @@ class SleeptimeMultiAgentV2(BaseAgent):
                 block_manager=self.block_manager,
                 passage_manager=self.passage_manager,
                 actor=self.actor,
+                step_manager=self.step_manager,
+                telemetry_manager=self.telemetry_manager,
             )
 
             # Perform sleeptime agent step
             result = await sleeptime_agent.step(
                 input_messages=sleeptime_agent_messages,
+                use_assistant_message=use_assistant_message,
             )
 
             # Update job status
@@ -237,7 +273,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
                     "agent_id": sleeptime_agent_id,
                 },
             )
-            self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.actor)
+            await self.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=self.actor)
             return result
         except Exception as e:
             job_update = JobUpdate(
@@ -245,5 +281,5 @@ class SleeptimeMultiAgentV2(BaseAgent):
                 completed_at=datetime.now(timezone.utc).replace(tzinfo=None),
                 metadata={"error": str(e)},
             )
-            self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.actor)
+            await self.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=self.actor)
             raise
diff --git 
a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py
index c042de03..dbff26ff 100644
--- a/letta/server/rest_api/routers/v1/agents.py
+++ b/letta/server/rest_api/routers/v1/agents.py
@@ -13,10 +13,11 @@
 from starlette.responses import Response, StreamingResponse
 from letta.agents.letta_agent import LettaAgent
 from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
+from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2
 from letta.helpers.datetime_helpers import get_utc_timestamp_ns
 from letta.log import get_logger
 from letta.orm.errors import NoResultFound
-from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent
+from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
 from letta.schemas.block import Block, BlockUpdate
 from letta.schemas.group import Group
 from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
@@ -637,22 +638,37 @@ async def send_message(
     actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
     # TODO: This is redundant, remove soon
     agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor)
-    agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent
+    agent_eligible = agent.enable_sleeptime or not agent.multi_agent_group
     experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false"
     feature_enabled = settings.use_experimental or experimental_header.lower() == "true"
     model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"]
 
     if agent_eligible and feature_enabled and model_compatible:
-        experimental_agent = LettaAgent(
-            agent_id=agent_id,
-            message_manager=server.message_manager,
-            agent_manager=server.agent_manager,
-            block_manager=server.block_manager,
-            passage_manager=server.passage_manager,
-            actor=actor,
-            step_manager=server.step_manager,
-            telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
-        )
+        if agent.enable_sleeptime:
+            experimental_agent = SleeptimeMultiAgentV2(
+                agent_id=agent_id,
+                message_manager=server.message_manager,
+                agent_manager=server.agent_manager,
+                block_manager=server.block_manager,
+                passage_manager=server.passage_manager,
+                group_manager=server.group_manager,
+                job_manager=server.job_manager,
+                actor=actor,
+                step_manager=server.step_manager,
+                telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
+                group=agent.multi_agent_group,
+            )
+        else:
+            experimental_agent = LettaAgent(
+                agent_id=agent_id,
+                message_manager=server.message_manager,
+                agent_manager=server.agent_manager,
+                block_manager=server.block_manager,
+                passage_manager=server.passage_manager,
+                actor=actor,
+                step_manager=server.step_manager,
+                telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
+            )
 
         result = await experimental_agent.step(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message)
     else:
@@ -699,23 +713,38 @@ async def send_message_streaming(
     actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
     # TODO: This is redundant, remove soon
     agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor)
-    agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent
+    agent_eligible = agent.enable_sleeptime or not agent.multi_agent_group
     experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false"
     feature_enabled = settings.use_experimental or experimental_header.lower() == "true"
     model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"]
     model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai"]
 
     if agent_eligible and feature_enabled and model_compatible:
-        experimental_agent = LettaAgent(
-            agent_id=agent_id,
-            
message_manager=server.message_manager, - agent_manager=server.agent_manager, - block_manager=server.block_manager, - passage_manager=server.passage_manager, - actor=actor, - step_manager=server.step_manager, - telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(), - ) + if agent.enable_sleeptime: + experimental_agent = SleeptimeMultiAgentV2( + agent_id=agent_id, + message_manager=server.message_manager, + agent_manager=server.agent_manager, + block_manager=server.block_manager, + passage_manager=server.passage_manager, + group_manager=server.group_manager, + job_manager=server.job_manager, + actor=actor, + step_manager=server.step_manager, + telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(), + group=agent.multi_agent_group, + ) + else: + experimental_agent = LettaAgent( + agent_id=agent_id, + message_manager=server.message_manager, + agent_manager=server.agent_manager, + block_manager=server.block_manager, + passage_manager=server.passage_manager, + actor=actor, + step_manager=server.step_manager, + telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(), + ) from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode if request.stream_tokens and model_compatible_token_streaming: diff --git a/letta/services/group_manager.py b/letta/services/group_manager.py index f8a27c42..4bce5825 100644 --- a/letta/services/group_manager.py +++ b/letta/services/group_manager.py @@ -232,6 +232,17 @@ class GroupManager: return group.turns_counter @trace_method + @enforce_types + async def bump_turns_counter_async(self, group_id: str, actor: PydanticUser) -> int: + async with db_registry.async_session() as session: + # Ensure group is loadable by user + group = await GroupModel.read_async(session, identifier=group_id, actor=actor) + + # Update turns counter + group.turns_counter = (group.turns_counter + 1) % 
group.sleeptime_agent_frequency + await group.update_async(session, actor=actor) + return group.turns_counter + @enforce_types def get_last_processed_message_id_and_update(self, group_id: str, last_processed_message_id: str, actor: PydanticUser) -> str: with db_registry.session() as session: @@ -246,6 +257,21 @@ class GroupManager: return prev_last_processed_message_id @trace_method + @enforce_types + async def get_last_processed_message_id_and_update_async( + self, group_id: str, last_processed_message_id: str, actor: PydanticUser + ) -> str: + async with db_registry.async_session() as session: + # Ensure group is loadable by user + group = await GroupModel.read_async(session, identifier=group_id, actor=actor) + + # Update last processed message id + prev_last_processed_message_id = group.last_processed_message_id + group.last_processed_message_id = last_processed_message_id + await group.update_async(session, actor=actor) + + return prev_last_processed_message_id + @enforce_types def size( self,