feat: support ConditionalToolRule in new agent loop (#1977)

This commit is contained in:
Andy Li
2025-05-12 10:43:47 -07:00
committed by GitHub
parent 314d30cb8f
commit 01b80d7b2b

View File

@@ -32,6 +32,7 @@ from letta.services.helpers.agent_manager_helper import compile_system_message
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
from letta.system import package_function_response
from letta.tracing import log_event, trace_method
from letta.utils import united_diff
@@ -59,6 +60,8 @@ class LettaAgent(BaseAgent):
self.use_assistant_message = use_assistant_message
self.response_messages: List[Message] = []
self.last_function_response = self._load_last_function_response()
@trace_method
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
@@ -194,7 +197,12 @@ class LettaAgent(BaseAgent):
or (t.tool_type == ToolType.EXTERNAL_COMPOSIO)
]
valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools]))
# Mirror the sync agent loop: get allowed tools or allow all if none are allowed
valid_tool_names = tool_rules_solver.get_allowed_tool_names(
available_tools=set([t.name for t in tools]),
last_function_response=self.last_function_response,
) or list(set(t.name for t in tools))
# TODO: Copied from legacy agent loop, so please be cautious
# Set force tool
force_tool_call = None
@@ -255,6 +263,7 @@ class LettaAgent(BaseAgent):
tool_args=tool_args,
agent_state=agent_state,
)
function_response = package_function_response(tool_result, success_flag)
# 4. Register tool call with tool rule solver
# Resolve whether or not to continue stepping
@@ -283,6 +292,7 @@ class LettaAgent(BaseAgent):
pre_computed_tool_message_id=pre_computed_tool_message_id,
)
persisted_messages = self.message_manager.create_many_messages(tool_call_messages, actor=self.actor)
self.last_function_response = function_response
return persisted_messages, continue_stepping
@@ -348,10 +358,6 @@ class LettaAgent(BaseAgent):
results = await self._send_message_to_agents_matching_tags(**tool_args)
log_event(name="finish_send_message_to_agents_matching_tags", attributes=tool_args)
return json.dumps(results), True
elif target_tool.tool_type == ToolType.EXTERNAL_COMPOSIO:
log_event(name=f"start_composio_{tool_name}_execution", attributes=tool_args)
log_event(name=f"finish_compsio_{tool_name}_execution", attributes=tool_args)
return tool_execution_result.func_return, True
else:
tool_execution_manager = ToolExecutionManager(agent_state=agent_state, actor=self.actor)
# TODO: Integrate sandbox result
@@ -416,3 +422,17 @@ class LettaAgent(BaseAgent):
tasks = [asyncio.create_task(process_agent(agent_state=agent_state, message=message)) for agent_state in matching_agents]
results = await asyncio.gather(*tasks)
return results
def _load_last_function_response(self):
"""Load the last function response from message history"""
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_id, actor=self.actor)
for msg in reversed(in_context_messages):
if msg.role == MessageRole.tool and msg.content and len(msg.content) == 1 and isinstance(msg.content[0], TextContent):
text_content = msg.content[0].text
try:
response_json = json.loads(text_content)
if response_json.get("message"):
return response_json["message"]
except (json.JSONDecodeError, KeyError):
raise ValueError(f"Invalid JSON format in message: {text_content}")
return None