feat: Improve retry mechanism for _get_ai_reply and refactor method (#2113)

This commit is contained in:
Matthew Zhou
2024-11-30 09:57:52 -08:00
committed by GitHub
parent c0329632d5
commit 58ec6238a8

View File

@@ -1,5 +1,6 @@
import datetime
import inspect
import time
import traceback
import warnings
from abc import ABC, abstractmethod
@@ -566,60 +567,60 @@ class Agent(BaseAgent):
self,
message_sequence: List[Message],
function_call: str = "auto",
first_message: bool = False, # hint
first_message: bool = False,
stream: bool = False, # TODO move to config?
fail_on_empty_response: bool = False,
empty_response_retry_limit: int = 3,
backoff_factor: float = 0.5, # delay multiplier for exponential backoff
max_delay: float = 10.0, # max delay between retries
) -> ChatCompletionResponse:
"""Get response from LLM API"""
# Get the allowed tools based on the ToolRulesSolver state
"""Get response from LLM API with robust retry mechanism."""
allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names()
allowed_functions = (
self.functions if not allowed_tool_names else [func for func in self.functions if func["name"] in allowed_tool_names]
)
if not allowed_tool_names:
# if it's empty, any available tools are fair game
allowed_functions = self.functions
else:
allowed_functions = [func for func in self.functions if func["name"] in allowed_tool_names]
for attempt in range(1, empty_response_retry_limit + 1):
try:
response = create(
llm_config=self.agent_state.llm_config,
messages=message_sequence,
user_id=self.agent_state.user_id,
functions=allowed_functions,
functions_python=self.functions_python,
function_call=function_call,
first_message=first_message,
stream=stream,
stream_interface=self.interface,
)
try:
response = create(
# agent_state=self.agent_state,
llm_config=self.agent_state.llm_config,
messages=message_sequence,
user_id=self.agent_state.user_id,
functions=allowed_functions,
functions_python=self.functions_python,
function_call=function_call,
# hint
first_message=first_message,
# streaming
stream=stream,
stream_interface=self.interface,
)
# These bottom two are retryable
if len(response.choices) == 0 or response.choices[0] is None:
raise ValueError(f"API call returned an empty message: {response}")
if len(response.choices) == 0 or response.choices[0] is None:
empty_api_err_message = f"API call didn't return a message: {response}"
if fail_on_empty_response or empty_response_retry_limit == 0:
raise Exception(empty_api_err_message)
if response.choices[0].finish_reason not in ["stop", "function_call", "tool_calls"]:
if response.choices[0].finish_reason == "length":
# This is not retryable, hence RuntimeError v.s. ValueError
raise RuntimeError("Finish reason was length (maximum context length)")
else:
raise ValueError(f"Bad finish reason from API: {response.choices[0].finish_reason}")
return response
except ValueError as ve:
if attempt >= empty_response_retry_limit:
warnings.warn(f"Retry limit reached. Final error: {ve}")
break
else:
# Decrement retry limit and try again
warnings.warn(empty_api_err_message)
return self._get_ai_reply(
message_sequence, function_call, first_message, stream, fail_on_empty_response, empty_response_retry_limit - 1
)
delay = min(backoff_factor * (2 ** (attempt - 1)), max_delay)
warnings.warn(f"Attempt {attempt} failed: {ve}. Retrying in {delay} seconds...")
time.sleep(delay)
# special case for 'length'
if response.choices[0].finish_reason == "length":
raise Exception("Finish reason was length (maximum context length)")
except Exception as e:
# For non-retryable errors, exit immediately
raise e
# catches for soft errors
if response.choices[0].finish_reason not in ["stop", "function_call", "tool_calls"]:
raise Exception(f"API call finish with bad finish reason: {response}")
# unpack with response.choices[0].message.content
return response
except Exception as e:
raise e
raise Exception("Retries exhausted and no valid response received.")
def _handle_ai_response(
self,