feat: stream tool return in async loop (#2172)
@@ -47,7 +47,6 @@ class LettaAgent(BaseAgent):
         block_manager: BlockManager,
         passage_manager: PassageManager,
         actor: User,
-        use_assistant_message: bool = True,
     ):
         super().__init__(agent_id=agent_id, openai_client=None, message_manager=message_manager, agent_manager=agent_manager, actor=actor)

@@ -55,7 +54,6 @@ class LettaAgent(BaseAgent):
         # Summarizer settings
         self.block_manager = block_manager
         self.passage_manager = passage_manager
-        self.use_assistant_message = use_assistant_message
         self.response_messages: List[Message] = []

         self.last_function_response = self._load_last_function_response()
@@ -65,12 +63,12 @@ class LettaAgent(BaseAgent):
         self.num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_id)

     @trace_method
-    async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
+    async def step(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True) -> LettaResponse:
         agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
         current_in_context_messages, new_in_context_messages = await self._step(
             agent_state=agent_state, input_messages=input_messages, max_steps=max_steps
         )
-        return _create_letta_response(new_in_context_messages=new_in_context_messages, use_assistant_message=self.use_assistant_message)
+        return _create_letta_response(new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message)

     async def _step(
         self, agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = 10
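With this change, `use_assistant_message` travels with each call to `step` instead of being fixed on the agent at construction time. A minimal sketch of the resulting call pattern, relying only on the signatures shown in this diff (the agent and message objects are assumed to already exist):

from typing import List

# Sketch only: one agent instance can now serve both rendering styles per request.
async def serve_two_styles(agent: "LettaAgent", messages: List["MessageCreate"]):
    # Raw tool-call rendering for one caller...
    raw = await agent.step(messages, max_steps=10, use_assistant_message=False)
    # ...and assistant-message rendering for another, without rebuilding the agent.
    chat = await agent.step(messages, max_steps=10, use_assistant_message=True)
    return raw, chat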
@@ -112,7 +110,7 @@ class LettaAgent(BaseAgent):

     @trace_method
     async def step_stream(
-        self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = False
+        self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True
     ) -> AsyncGenerator[str, None]:
         """
         Main streaming loop that yields partial tokens.
@@ -160,6 +158,10 @@ class LettaAgent(BaseAgent):
             self.response_messages.extend(persisted_messages)
             new_in_context_messages.extend(persisted_messages)

+            if not use_assistant_message or should_continue:
+                tool_return = persisted_messages[-1].to_letta_messages()[0]
+                yield f"data: {tool_return.model_dump_json()}\n\n"
+
             if not should_continue:
                 break

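The four added lines are the heart of the commit: when assistant-message rendering is off, or the loop is going to take another step, the freshly persisted tool result is converted to its LettaMessage form and pushed to the client immediately as a server-sent event, rather than being withheld until the loop finishes. A rough standalone sketch of that SSE framing follows; the payloads here are plain dicts standing in for Letta's tool-return messages:

import json
from typing import AsyncGenerator, Dict, List

async def stream_tool_returns(tool_results: List[Dict]) -> AsyncGenerator[str, None]:
    # Each chunk is one SSE "data:" frame terminated by a blank line, mirroring
    # `yield f"data: {tool_return.model_dump_json()}\n\n"` in the loop above.
    for result in tool_results:
        yield f"data: {json.dumps(result)}\n\n"

Yielding inside the loop means a streaming client observes each tool execution as it completes, instead of receiving one batch after `step_stream` returns.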
@@ -359,7 +361,6 @@ class LettaAgent(BaseAgent):
             block_manager=self.block_manager,
             passage_manager=self.passage_manager,
             actor=self.actor,
-            use_assistant_message=True,
         )

         augmented_message = (
@@ -107,7 +107,6 @@ class LettaAgentBatch(BaseAgent):
         sandbox_config_manager: SandboxConfigManager,
         job_manager: JobManager,
         actor: User,
-        use_assistant_message: bool = True,
         max_steps: int = 10,
     ):
         self.message_manager = message_manager
@@ -117,7 +116,6 @@ class LettaAgentBatch(BaseAgent):
         self.batch_manager = batch_manager
         self.sandbox_config_manager = sandbox_config_manager
         self.job_manager = job_manager
-        self.use_assistant_message = use_assistant_message
         self.actor = actor
         self.max_steps = max_steps

@@ -58,7 +58,7 @@ class VoiceSleeptimeAgent(LettaAgent):
     def update_message_transcript(self, message_transcripts: List[str]):
         self.message_transcripts = message_transcripts

-    async def step(self, input_messages: List[MessageCreate], max_steps: int = 20) -> LettaResponse:
+    async def step(self, input_messages: List[MessageCreate], max_steps: int = 20, use_assistant_message: bool = True) -> LettaResponse:
         """
         Process the user's input message, allowing the model to call memory-related tools
         until it decides to stop and provide a final response.
@@ -84,7 +84,7 @@ class VoiceSleeptimeAgent(LettaAgent):
             agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor
         )

-        return _create_letta_response(new_in_context_messages=new_in_context_messages, use_assistant_message=self.use_assistant_message)
+        return _create_letta_response(new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message)

     @trace_method
     async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]:
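For context on what the flag controls downstream: `use_assistant_message` decides whether `send_message` tool calls are surfaced as plain assistant messages or left as raw tool-call and tool-return entries. The helper below only illustrates that branching over assumed dict-shaped messages; it is not the real `_create_letta_response`:

def render_message(msg: dict, use_assistant_message: bool) -> dict:
    # Hypothetical shape: {"tool_name": ..., "arguments": {...}} for a tool call.
    if use_assistant_message and msg.get("tool_name") == "send_message":
        # Surface the send_message call as ordinary assistant text.
        return {"message_type": "assistant_message", "content": msg["arguments"].get("message", "")}
    # Otherwise keep the raw tool-call representation.
    return {"message_type": "tool_call_message", "tool_name": msg.get("tool_name"), "arguments": msg.get("arguments")}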
@@ -146,7 +146,7 @@ class VoiceSleeptimeAgent(LettaAgent):
             return f"Failed to store memory given start_index {start_index} and end_index {end_index}: {e}", False

     async def step_stream(
-        self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = False
+        self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True
     ) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
         """
         This agent is synchronous-only. If called in an async context, raise an error.
@@ -646,7 +646,7 @@ async def send_message(
             actor=actor,
         )

-        result = await experimental_agent.step(request.messages, max_steps=10)
+        result = await experimental_agent.step(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message)
     else:
         result = await server.send_message_to_agent(
             agent_id=agent_id,
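At the REST layer the flag is read straight off the request body and forwarded into the experimental agent's `step`, so the rendering choice is made per API call. A hypothetical client invocation, with base URL, agent id, and message content as placeholders:

import requests

# Hypothetical request against a local Letta server; adjust URL and auth as needed.
resp = requests.post(
    "http://localhost:8283/v1/agents/agent-123/messages",
    json={
        "messages": [{"role": "user", "content": "hello"}],
        # Forwarded as request.use_assistant_message in the handler above.
        "use_assistant_message": False,
    },
)
print(resp.json())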