diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 9dba906d..ca7e1fb7 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -32,6 +32,7 @@ from letta.services.helpers.agent_manager_helper import compile_system_message from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager +from letta.system import package_function_response from letta.tracing import log_event, trace_method from letta.utils import united_diff @@ -59,6 +60,8 @@ class LettaAgent(BaseAgent): self.use_assistant_message = use_assistant_message self.response_messages: List[Message] = [] + self.last_function_response = self._load_last_function_response() + @trace_method async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse: agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor) @@ -194,7 +197,12 @@ class LettaAgent(BaseAgent): or (t.tool_type == ToolType.EXTERNAL_COMPOSIO) ] - valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools])) + # Mirror the sync agent loop: get allowed tools or allow all if none are allowed + valid_tool_names = tool_rules_solver.get_allowed_tool_names( + available_tools=set([t.name for t in tools]), + last_function_response=self.last_function_response, + ) or list(set(t.name for t in tools)) + # TODO: Copied from legacy agent loop, so please be cautious # Set force tool force_tool_call = None @@ -255,6 +263,7 @@ class LettaAgent(BaseAgent): tool_args=tool_args, agent_state=agent_state, ) + function_response = package_function_response(tool_result, success_flag) # 4. Register tool call with tool rule solver # Resolve whether or not to continue stepping @@ -283,6 +292,7 @@ class LettaAgent(BaseAgent): pre_computed_tool_message_id=pre_computed_tool_message_id, ) persisted_messages = self.message_manager.create_many_messages(tool_call_messages, actor=self.actor) + self.last_function_response = function_response return persisted_messages, continue_stepping @@ -348,10 +358,6 @@ class LettaAgent(BaseAgent): results = await self._send_message_to_agents_matching_tags(**tool_args) log_event(name="finish_send_message_to_agents_matching_tags", attributes=tool_args) return json.dumps(results), True - elif target_tool.tool_type == ToolType.EXTERNAL_COMPOSIO: - log_event(name=f"start_composio_{tool_name}_execution", attributes=tool_args) - log_event(name=f"finish_compsio_{tool_name}_execution", attributes=tool_args) - return tool_execution_result.func_return, True else: tool_execution_manager = ToolExecutionManager(agent_state=agent_state, actor=self.actor) # TODO: Integrate sandbox result @@ -416,3 +422,17 @@ class LettaAgent(BaseAgent): tasks = [asyncio.create_task(process_agent(agent_state=agent_state, message=message)) for agent_state in matching_agents] results = await asyncio.gather(*tasks) return results + + def _load_last_function_response(self): + """Load the last function response from message history""" + in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_id, actor=self.actor) + for msg in reversed(in_context_messages): + if msg.role == MessageRole.tool and msg.content and len(msg.content) == 1 and isinstance(msg.content[0], TextContent): + text_content = msg.content[0].text + try: + response_json = json.loads(text_content) + if response_json.get("message"): + return response_json["message"] + except (json.JSONDecodeError, KeyError): + raise ValueError(f"Invalid JSON format in message: {text_content}") + return None