feat: support ConditionalToolRule in new agent loop (#1977)
This commit is contained in:
@@ -32,6 +32,7 @@ from letta.services.helpers.agent_manager_helper import compile_system_message
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
|
||||
from letta.system import package_function_response
|
||||
from letta.tracing import log_event, trace_method
|
||||
from letta.utils import united_diff
|
||||
|
||||
@@ -59,6 +60,8 @@ class LettaAgent(BaseAgent):
|
||||
self.use_assistant_message = use_assistant_message
|
||||
self.response_messages: List[Message] = []
|
||||
|
||||
self.last_function_response = self._load_last_function_response()
|
||||
|
||||
@trace_method
|
||||
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
|
||||
agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
|
||||
@@ -194,7 +197,12 @@ class LettaAgent(BaseAgent):
|
||||
or (t.tool_type == ToolType.EXTERNAL_COMPOSIO)
|
||||
]
|
||||
|
||||
valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools]))
|
||||
# Mirror the sync agent loop: get allowed tools or allow all if none are allowed
|
||||
valid_tool_names = tool_rules_solver.get_allowed_tool_names(
|
||||
available_tools=set([t.name for t in tools]),
|
||||
last_function_response=self.last_function_response,
|
||||
) or list(set(t.name for t in tools))
|
||||
|
||||
# TODO: Copied from legacy agent loop, so please be cautious
|
||||
# Set force tool
|
||||
force_tool_call = None
|
||||
@@ -255,6 +263,7 @@ class LettaAgent(BaseAgent):
|
||||
tool_args=tool_args,
|
||||
agent_state=agent_state,
|
||||
)
|
||||
function_response = package_function_response(tool_result, success_flag)
|
||||
|
||||
# 4. Register tool call with tool rule solver
|
||||
# Resolve whether or not to continue stepping
|
||||
@@ -283,6 +292,7 @@ class LettaAgent(BaseAgent):
|
||||
pre_computed_tool_message_id=pre_computed_tool_message_id,
|
||||
)
|
||||
persisted_messages = self.message_manager.create_many_messages(tool_call_messages, actor=self.actor)
|
||||
self.last_function_response = function_response
|
||||
|
||||
return persisted_messages, continue_stepping
|
||||
|
||||
@@ -348,10 +358,6 @@ class LettaAgent(BaseAgent):
|
||||
results = await self._send_message_to_agents_matching_tags(**tool_args)
|
||||
log_event(name="finish_send_message_to_agents_matching_tags", attributes=tool_args)
|
||||
return json.dumps(results), True
|
||||
elif target_tool.tool_type == ToolType.EXTERNAL_COMPOSIO:
|
||||
log_event(name=f"start_composio_{tool_name}_execution", attributes=tool_args)
|
||||
log_event(name=f"finish_compsio_{tool_name}_execution", attributes=tool_args)
|
||||
return tool_execution_result.func_return, True
|
||||
else:
|
||||
tool_execution_manager = ToolExecutionManager(agent_state=agent_state, actor=self.actor)
|
||||
# TODO: Integrate sandbox result
|
||||
@@ -416,3 +422,17 @@ class LettaAgent(BaseAgent):
|
||||
tasks = [asyncio.create_task(process_agent(agent_state=agent_state, message=message)) for agent_state in matching_agents]
|
||||
results = await asyncio.gather(*tasks)
|
||||
return results
|
||||
|
||||
def _load_last_function_response(self):
|
||||
"""Load the last function response from message history"""
|
||||
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_id, actor=self.actor)
|
||||
for msg in reversed(in_context_messages):
|
||||
if msg.role == MessageRole.tool and msg.content and len(msg.content) == 1 and isinstance(msg.content[0], TextContent):
|
||||
text_content = msg.content[0].text
|
||||
try:
|
||||
response_json = json.loads(text_content)
|
||||
if response_json.get("message"):
|
||||
return response_json["message"]
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
raise ValueError(f"Invalid JSON format in message: {text_content}")
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user