diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 35aad811..834098f9 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -57,6 +57,7 @@ class LettaAgent(BaseAgent): self.block_manager = block_manager self.passage_manager = passage_manager self.use_assistant_message = use_assistant_message + self.response_messages: List[Message] = [] @trace_method async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse: @@ -81,6 +82,7 @@ class LettaAgent(BaseAgent): tool_call = response.choices[0].message.tool_calls[0] persisted_messages, should_continue = await self._handle_ai_response(tool_call, agent_state, tool_rules_solver) + self.response_messages.extend(persisted_messages) new_in_context_messages.extend(persisted_messages) if not should_continue: @@ -139,6 +141,7 @@ class LettaAgent(BaseAgent): pre_computed_assistant_message_id=interface.letta_assistant_message_id, pre_computed_tool_message_id=interface.letta_tool_message_id, ) + self.response_messages.extend(persisted_messages) new_in_context_messages.extend(persisted_messages) if not should_continue: @@ -167,7 +170,14 @@ class LettaAgent(BaseAgent): tools = [ t for t in agent_state.tools - if t.tool_type in {ToolType.CUSTOM, ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE} + if t.tool_type + in { + ToolType.CUSTOM, + ToolType.LETTA_CORE, + ToolType.LETTA_MEMORY_CORE, + ToolType.LETTA_MULTI_AGENT_CORE, + ToolType.LETTA_SLEEPTIME_CORE, + } or (t.tool_type == ToolType.LETTA_MULTI_AGENT_CORE and t.name == "send_message_to_agents_matching_tags") ] diff --git a/letta/groups/helpers.py b/letta/groups/helpers.py index 712dbcf8..039230df 100644 --- a/letta/groups/helpers.py +++ b/letta/groups/helpers.py @@ -88,11 +88,14 @@ def load_multi_agent( def stringify_message(message: Message, use_assistant_name: bool = False) -> str | None: assistant_name = message.name or "assistant" if use_assistant_name else "assistant" if message.role == "user": - content = json.loads(message.content[0].text) - if content["type"] == "user_message": - return f"{message.name or 'user'}: {content['message']}" - else: - return None + try: + content = json.loads(message.content[0].text) + if content["type"] == "user_message": + return f"{message.name or 'user'}: {content['message']}" + else: + return None + except: + return f"{message.name or 'user'}: {message.content[0].text}" elif message.role == "assistant": messages = [] if message.tool_calls: diff --git a/letta/groups/sleeptime_multi_agent.py b/letta/groups/sleeptime_multi_agent.py index 6114c552..6349b57b 100644 --- a/letta/groups/sleeptime_multi_agent.py +++ b/letta/groups/sleeptime_multi_agent.py @@ -107,6 +107,9 @@ class SleeptimeMultiAgent(Agent): run_id: str, ) -> LettaUsageStatistics: try: + job_update = JobUpdate(status=JobStatus.running) + self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.user) + participant_agent_state = self.agent_manager.get_agent_by_id(participant_agent_id, actor=self.user) participant_agent = Agent( agent_state=participant_agent_state, diff --git a/letta/groups/sleeptime_multi_agent_v2.py b/letta/groups/sleeptime_multi_agent_v2.py new file mode 100644 index 00000000..9dc591f5 --- /dev/null +++ b/letta/groups/sleeptime_multi_agent_v2.py @@ -0,0 +1,249 @@ +import asyncio +from datetime import datetime, timezone +from typing import AsyncGenerator, List, Optional + +from letta.agents.base_agent import BaseAgent +from letta.agents.letta_agent import LettaAgent +from letta.groups.helpers import stringify_message +from letta.schemas.enums import JobStatus +from letta.schemas.group import Group, ManagerType +from letta.schemas.job import JobUpdate +from letta.schemas.letta_message_content import TextContent +from letta.schemas.letta_response import LettaResponse +from letta.schemas.message import Message, MessageCreate +from letta.schemas.run import Run +from letta.schemas.user import User +from letta.services.agent_manager import AgentManager +from letta.services.block_manager import BlockManager +from letta.services.group_manager import GroupManager +from letta.services.job_manager import JobManager +from letta.services.message_manager import MessageManager +from letta.services.passage_manager import PassageManager + + +class SleeptimeMultiAgentV2(BaseAgent): + def __init__( + self, + agent_id: str, + message_manager: MessageManager, + agent_manager: AgentManager, + block_manager: BlockManager, + passage_manager: PassageManager, + group_manager: GroupManager, + job_manager: JobManager, + actor: User, + group: Optional[Group] = None, + ): + super().__init__( + agent_id=agent_id, + openai_client=None, + message_manager=message_manager, + agent_manager=agent_manager, + actor=actor, + ) + self.block_manager = block_manager + self.passage_manager = passage_manager + self.group_manager = group_manager + self.job_manager = job_manager + # Group settings + assert group.manager_type == ManagerType.sleeptime, f"Expected group manager type to be 'sleeptime', got {group.manager_type}" + self.group = group + + async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse: + run_ids = [] + + # Prepare new messages + new_messages = [] + for message in input_messages: + if isinstance(message.content, str): + message.content = [TextContent(text=message.content)] + message.group_id = self.group.id + new_messages.append(message) + + # Load foreground agent + foreground_agent = LettaAgent( + agent_id=self.agent_id, + message_manager=self.message_manager, + agent_manager=self.agent_manager, + block_manager=self.block_manager, + passage_manager=self.passage_manager, + actor=self.actor, + ) + # Perform foreground agent step + response = await foreground_agent.step(input_messages=new_messages, max_steps=max_steps) + + # Get last response messages + last_response_messages = foreground_agent.response_messages + + # Update turns counter + if self.group.sleeptime_agent_frequency is not None and self.group.sleeptime_agent_frequency > 0: + turns_counter = self.group_manager.bump_turns_counter(group_id=self.group.id, actor=self.actor) + + # Perform participant steps + if self.group.sleeptime_agent_frequency is None or ( + turns_counter is not None and turns_counter % self.group.sleeptime_agent_frequency == 0 + ): + last_processed_message_id = self.group_manager.get_last_processed_message_id_and_update( + group_id=self.group.id, last_processed_message_id=last_response_messages[-1].id, actor=self.actor + ) + for participant_agent_id in self.group.agent_ids: + try: + run_id = await self._issue_background_task( + participant_agent_id, + last_response_messages, + last_processed_message_id, + ) + run_ids.append(run_id) + + except Exception as e: + # Individual task failures + print(f"Agent processing failed: {str(e)}") + raise e + + response.usage.run_ids = run_ids + return response + + async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = 10) -> AsyncGenerator[str, None]: + # Prepare new messages + new_messages = [] + for message in input_messages: + if isinstance(message.content, str): + message.content = [TextContent(text=message.content)] + message.group_id = self.group.id + new_messages.append(message) + + # Load foreground agent + foreground_agent = LettaAgent( + agent_id=self.agent_id, + message_manager=self.message_manager, + agent_manager=self.agent_manager, + block_manager=self.block_manager, + passage_manager=self.passage_manager, + actor=self.actor, + ) + # Perform foreground agent step + async for chunk in foreground_agent.step_stream(input_messages=new_messages, max_steps=max_steps): + yield chunk + + # Get response messages + last_response_messages = foreground_agent.response_messages + + # Update turns counter + if self.group.sleeptime_agent_frequency is not None and self.group.sleeptime_agent_frequency > 0: + turns_counter = self.group_manager.bump_turns_counter(group_id=self.group.id, actor=self.actor) + + # Perform participant steps + if self.group.sleeptime_agent_frequency is None or ( + turns_counter is not None and turns_counter % self.group.sleeptime_agent_frequency == 0 + ): + last_processed_message_id = self.group_manager.get_last_processed_message_id_and_update( + group_id=self.group.id, last_processed_message_id=last_response_messages[-1].id, actor=self.actor + ) + for sleeptime_agent_id in self.group.agent_ids: + self._issue_background_task( + sleeptime_agent_id, + last_response_messages, + last_processed_message_id, + ) + + async def _issue_background_task( + self, + sleeptime_agent_id: str, + response_messages: List[Message], + last_processed_message_id: str, + ) -> str: + run = Run( + user_id=self.actor.id, + status=JobStatus.created, + metadata={ + "job_type": "sleeptime_agent_send_message_async", # is this right? + "agent_id": sleeptime_agent_id, + }, + ) + run = self.job_manager.create_job(pydantic_job=run, actor=self.actor) + + asyncio.create_task( + self._participant_agent_step( + foreground_agent_id=self.agent_id, + sleeptime_agent_id=sleeptime_agent_id, + response_messages=response_messages, + last_processed_message_id=last_processed_message_id, + run_id=run.id, + ) + ) + return run.id + + async def _participant_agent_step( + self, + foreground_agent_id: str, + sleeptime_agent_id: str, + response_messages: List[Message], + last_processed_message_id: str, + run_id: str, + ) -> str: + try: + # Update job status + job_update = JobUpdate(status=JobStatus.running) + self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.actor) + + # Create conversation transcript + prior_messages = [] + if self.group.sleeptime_agent_frequency: + try: + prior_messages = self.message_manager.list_messages_for_agent( + agent_id=foreground_agent_id, + actor=self.actor, + after=last_processed_message_id, + before=response_messages[0].id, + ) + except Exception: + pass # continue with just latest messages + + transcript_summary = [stringify_message(message) for message in prior_messages + response_messages] + transcript_summary = [summary for summary in transcript_summary if summary is not None] + message_text = "\n".join(transcript_summary) + + sleeptime_agent_messages = [ + MessageCreate( + role="user", + content=[TextContent(text=message_text)], + id=Message.generate_id(), + agent_id=sleeptime_agent_id, + group_id=self.group.id, + ) + ] + + # Load sleeptime agent + sleeptime_agent = LettaAgent( + agent_id=sleeptime_agent_id, + message_manager=self.message_manager, + agent_manager=self.agent_manager, + block_manager=self.block_manager, + passage_manager=self.passage_manager, + actor=self.actor, + ) + + # Perform sleeptime agent step + result = await sleeptime_agent.step( + input_messages=sleeptime_agent_messages, + ) + + # Update job status + job_update = JobUpdate( + status=JobStatus.completed, + completed_at=datetime.now(timezone.utc), + metadata={ + "result": result.model_dump(mode="json"), + "agent_id": sleeptime_agent_id, + }, + ) + self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.actor) + return result + except Exception as e: + job_update = JobUpdate( + status=JobStatus.failed, + completed_at=datetime.now(timezone.utc), + metadata={"error": str(e)}, + ) + self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.actor) + raise diff --git a/letta/services/tool_executor/tool_executor.py b/letta/services/tool_executor/tool_executor.py index a51b5e3c..7d9cac41 100644 --- a/letta/services/tool_executor/tool_executor.py +++ b/letta/services/tool_executor/tool_executor.py @@ -192,7 +192,7 @@ class LettaCoreToolExecutor(ToolExecutor): AgentManager().rebuild_system_prompt(agent_id=agent_state.id, actor=actor, force=True) return None - def core_memory_append(self, agent_state: "AgentState", actor: User, label: str, content: str) -> Optional[str]: + def core_memory_append(self, agent_state: AgentState, actor: User, label: str, content: str) -> Optional[str]: """ Append to the contents of core memory. @@ -211,7 +211,7 @@ class LettaCoreToolExecutor(ToolExecutor): def core_memory_replace( self, - agent_state: "AgentState", + agent_state: AgentState, actor: User, label: str, old_content: str, @@ -237,7 +237,8 @@ class LettaCoreToolExecutor(ToolExecutor): return None def memory_replace( - agent_state: "AgentState", + self, + agent_state: AgentState, actor: User, label: str, old_str: str, @@ -326,7 +327,8 @@ class LettaCoreToolExecutor(ToolExecutor): return success_msg def memory_insert( - agent_state: "AgentState", + self, + agent_state: AgentState, actor: User, label: str, new_str: str, @@ -407,7 +409,7 @@ class LettaCoreToolExecutor(ToolExecutor): return success_msg - def memory_rethink(agent_state: "AgentState", actor: User, label: str, new_memory: str) -> str: + def memory_rethink(self, agent_state: AgentState, actor: User, label: str, new_memory: str) -> str: """ The memory_rethink command allows you to completely rewrite the contents of a memory block. Use this tool to make large sweeping changes (e.g. when you want @@ -458,7 +460,7 @@ class LettaCoreToolExecutor(ToolExecutor): # return None return success_msg - def memory_finish_edits(agent_state: "AgentState") -> None: + def memory_finish_edits(self, agent_state: AgentState, actor: User) -> None: """ Call the memory_finish_edits command when you are finished making edits (integrating all new information) into the memory blocks. This function diff --git a/tests/integration_test_sleeptime_agent.py b/tests/integration_test_sleeptime_agent.py index 30bc3517..0749b399 100644 --- a/tests/integration_test_sleeptime_agent.py +++ b/tests/integration_test_sleeptime_agent.py @@ -5,6 +5,7 @@ from sqlalchemy import delete from letta.config import LettaConfig from letta.constants import DEFAULT_HUMAN +from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2 from letta.orm import Provider, Step from letta.orm.enums import JobType from letta.orm.errors import NoResultFound @@ -152,10 +153,132 @@ async def test_sleeptime_group_chat(server, actor): assert len(agent_runs) == len(run_ids) # 6. Verify run status after sleep - time.sleep(10) + time.sleep(2) for run_id in run_ids: job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor) - assert job.status == JobStatus.completed + assert job.status == JobStatus.running or job.status == JobStatus.completed + + # 7. Delete agent + server.agent_manager.delete_agent(agent_id=main_agent.id, actor=actor) + + with pytest.raises(NoResultFound): + server.group_manager.retrieve_group(group_id=group.id, actor=actor) + with pytest.raises(NoResultFound): + server.agent_manager.get_agent_by_id(agent_id=sleeptime_agent_id, actor=actor) + + +@pytest.mark.asyncio +async def test_sleeptime_group_chat_v2(server, actor): + # 0. Refresh base tools + server.tool_manager.upsert_base_tools(actor=actor) + + # 1. Create sleeptime agent + main_agent = server.create_agent( + request=CreateAgent( + name="main_agent", + memory_blocks=[ + CreateBlock( + label="persona", + value="You are a personal assistant that helps users with requests.", + ), + CreateBlock( + label="human", + value="My favorite plant is the fiddle leaf\nMy favorite color is lavender", + ), + ], + # model="openai/gpt-4o-mini", + model="anthropic/claude-3-5-sonnet-20240620", + embedding="openai/text-embedding-ada-002", + enable_sleeptime=True, + ), + actor=actor, + ) + + assert main_agent.enable_sleeptime == True + main_agent_tools = [tool.name for tool in main_agent.tools] + assert "core_memory_append" not in main_agent_tools + assert "core_memory_replace" not in main_agent_tools + assert "archival_memory_insert" not in main_agent_tools + + # 2. Override frequency for test + group = server.group_manager.modify_group( + group_id=main_agent.multi_agent_group.id, + group_update=GroupUpdate( + manager_config=SleeptimeManagerUpdate( + sleeptime_agent_frequency=2, + ), + ), + actor=actor, + ) + + assert group.manager_type == ManagerType.sleeptime + assert group.sleeptime_agent_frequency == 2 + assert len(group.agent_ids) == 1 + + # 3. Verify shared blocks + sleeptime_agent_id = group.agent_ids[0] + shared_block = server.agent_manager.get_block_with_label(agent_id=main_agent.id, block_label="human", actor=actor) + agents = server.block_manager.get_agents_for_block(block_id=shared_block.id, actor=actor) + assert len(agents) == 2 + assert sleeptime_agent_id in [agent.id for agent in agents] + assert main_agent.id in [agent.id for agent in agents] + + # 4 Verify sleeptime agent tools + sleeptime_agent = server.agent_manager.get_agent_by_id(agent_id=sleeptime_agent_id, actor=actor) + sleeptime_agent_tools = [tool.name for tool in sleeptime_agent.tools] + assert "memory_rethink" in sleeptime_agent_tools + assert "memory_finish_edits" in sleeptime_agent_tools + assert "memory_replace" in sleeptime_agent_tools + assert "memory_insert" in sleeptime_agent_tools + + assert len([rule for rule in sleeptime_agent.tool_rules if rule.type == ToolRuleType.exit_loop]) > 0 + + # 5. Send messages and verify run ids + message_text = [ + "my favorite color is orange", + "not particularly. today is a good day", + "actually my favorite color is coral", + "let's change the subject", + "actually my fav plant is the the african spear", + "indeed", + ] + run_ids = [] + for i, text in enumerate(message_text): + agent = SleeptimeMultiAgentV2( + agent_id=main_agent.id, + message_manager=server.message_manager, + agent_manager=server.agent_manager, + block_manager=server.block_manager, + passage_manager=server.passage_manager, + group_manager=server.group_manager, + job_manager=server.job_manager, + actor=actor, + group=main_agent.multi_agent_group, + ) + + response = await agent.step( + input_messages=[ + MessageCreate( + role="user", + content=text, + ), + ], + ) + + assert len(response.messages) > 0 + assert len(response.usage.run_ids or []) == (i + 1) % 2 + run_ids.extend(response.usage.run_ids or []) + + jobs = server.job_manager.list_jobs(actor=actor, job_type=JobType.RUN) + runs = [Run.from_job(job) for job in jobs] + agent_runs = [run for run in runs if "agent_id" in run.metadata and run.metadata["agent_id"] == sleeptime_agent_id] + assert len(agent_runs) == len(run_ids) + + # 6. Verify run status after sleep + time.sleep(2) + for run_id in run_ids: + job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor) + assert job.status == JobStatus.running or job.status == JobStatus.completed # 7. Delete agent server.agent_manager.delete_agent(agent_id=main_agent.id, actor=actor)