From 86b8216adfdbc492f00d2944508ebfd4215f8b41 Mon Sep 17 00:00:00 2001 From: cthomas Date: Tue, 13 May 2025 22:36:56 -0700 Subject: [PATCH] feat: stream tool return in async loop (#2172) --- letta/agents/letta_agent.py | 13 +++++++------ letta/agents/letta_agent_batch.py | 2 -- letta/agents/voice_sleeptime_agent.py | 6 +++--- letta/server/rest_api/routers/v1/agents.py | 2 +- 4 files changed, 11 insertions(+), 12 deletions(-) diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 2b20cfaf..8aca8e1c 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -47,7 +47,6 @@ class LettaAgent(BaseAgent): block_manager: BlockManager, passage_manager: PassageManager, actor: User, - use_assistant_message: bool = True, ): super().__init__(agent_id=agent_id, openai_client=None, message_manager=message_manager, agent_manager=agent_manager, actor=actor) @@ -55,7 +54,6 @@ class LettaAgent(BaseAgent): # Summarizer settings self.block_manager = block_manager self.passage_manager = passage_manager - self.use_assistant_message = use_assistant_message self.response_messages: List[Message] = [] self.last_function_response = self._load_last_function_response() @@ -65,12 +63,12 @@ class LettaAgent(BaseAgent): self.num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_id) @trace_method - async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse: + async def step(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True) -> LettaResponse: agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor) current_in_context_messages, new_in_context_messages = await self._step( agent_state=agent_state, input_messages=input_messages, max_steps=max_steps ) - return _create_letta_response(new_in_context_messages=new_in_context_messages, use_assistant_message=self.use_assistant_message) + return _create_letta_response(new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message) async def _step( self, agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = 10 @@ -112,7 +110,7 @@ class LettaAgent(BaseAgent): @trace_method async def step_stream( - self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = False + self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True ) -> AsyncGenerator[str, None]: """ Main streaming loop that yields partial tokens. @@ -160,6 +158,10 @@ class LettaAgent(BaseAgent): self.response_messages.extend(persisted_messages) new_in_context_messages.extend(persisted_messages) + if not use_assistant_message or should_continue: + tool_return = persisted_messages[-1].to_letta_messages()[0] + yield f"data: {tool_return.model_dump_json()}\n\n" + if not should_continue: break @@ -359,7 +361,6 @@ class LettaAgent(BaseAgent): block_manager=self.block_manager, passage_manager=self.passage_manager, actor=self.actor, - use_assistant_message=True, ) augmented_message = ( diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index cbe17e1a..ba426688 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -107,7 +107,6 @@ class LettaAgentBatch(BaseAgent): sandbox_config_manager: SandboxConfigManager, job_manager: JobManager, actor: User, - use_assistant_message: bool = True, max_steps: int = 10, ): self.message_manager = message_manager @@ -117,7 +116,6 @@ class LettaAgentBatch(BaseAgent): self.batch_manager = batch_manager self.sandbox_config_manager = sandbox_config_manager self.job_manager = job_manager - self.use_assistant_message = use_assistant_message self.actor = actor self.max_steps = max_steps diff --git a/letta/agents/voice_sleeptime_agent.py b/letta/agents/voice_sleeptime_agent.py index d3e2b70c..9ed3bc26 100644 --- a/letta/agents/voice_sleeptime_agent.py +++ b/letta/agents/voice_sleeptime_agent.py @@ -58,7 +58,7 @@ class VoiceSleeptimeAgent(LettaAgent): def update_message_transcript(self, message_transcripts: List[str]): self.message_transcripts = message_transcripts - async def step(self, input_messages: List[MessageCreate], max_steps: int = 20) -> LettaResponse: + async def step(self, input_messages: List[MessageCreate], max_steps: int = 20, use_assistant_message: bool = True) -> LettaResponse: """ Process the user's input message, allowing the model to call memory-related tools until it decides to stop and provide a final response. @@ -84,7 +84,7 @@ class VoiceSleeptimeAgent(LettaAgent): agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor ) - return _create_letta_response(new_in_context_messages=new_in_context_messages, use_assistant_message=self.use_assistant_message) + return _create_letta_response(new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message) @trace_method async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]: @@ -146,7 +146,7 @@ class VoiceSleeptimeAgent(LettaAgent): return f"Failed to store memory given start_index {start_index} and end_index {end_index}: {e}", False async def step_stream( - self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = False + self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True ) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]: """ This agent is synchronous-only. If called in an async context, raise an error. diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 113beb68..226aba3b 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -646,7 +646,7 @@ async def send_message( actor=actor, ) - result = await experimental_agent.step(request.messages, max_steps=10) + result = await experimental_agent.step(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message) else: result = await server.send_message_to_agent( agent_id=agent_id,