From 58ec6238a861e60f8e4f43216c0e4f51552fb3da Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Sat, 30 Nov 2024 09:57:52 -0800 Subject: [PATCH] feat: Improve retry mechanism for `_get_ai_reply` and refactor method (#2113) --- letta/agent.py | 89 +++++++++++++++++++++++++------------------------- 1 file changed, 45 insertions(+), 44 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index 73f0199c..ebd73aa9 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -1,5 +1,6 @@ import datetime import inspect +import time import traceback import warnings from abc import ABC, abstractmethod @@ -566,60 +567,60 @@ class Agent(BaseAgent): self, message_sequence: List[Message], function_call: str = "auto", - first_message: bool = False, # hint + first_message: bool = False, stream: bool = False, # TODO move to config? - fail_on_empty_response: bool = False, empty_response_retry_limit: int = 3, + backoff_factor: float = 0.5, # delay multiplier for exponential backoff + max_delay: float = 10.0, # max delay between retries ) -> ChatCompletionResponse: - """Get response from LLM API""" - # Get the allowed tools based on the ToolRulesSolver state + """Get response from LLM API with robust retry mechanism.""" + allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names() + allowed_functions = ( + self.functions if not allowed_tool_names else [func for func in self.functions if func["name"] in allowed_tool_names] + ) - if not allowed_tool_names: - # if it's empty, any available tools are fair game - allowed_functions = self.functions - else: - allowed_functions = [func for func in self.functions if func["name"] in allowed_tool_names] + for attempt in range(1, empty_response_retry_limit + 1): + try: + response = create( + llm_config=self.agent_state.llm_config, + messages=message_sequence, + user_id=self.agent_state.user_id, + functions=allowed_functions, + functions_python=self.functions_python, + function_call=function_call, + first_message=first_message, + stream=stream, + stream_interface=self.interface, + ) - try: - response = create( - # agent_state=self.agent_state, - llm_config=self.agent_state.llm_config, - messages=message_sequence, - user_id=self.agent_state.user_id, - functions=allowed_functions, - functions_python=self.functions_python, - function_call=function_call, - # hint - first_message=first_message, - # streaming - stream=stream, - stream_interface=self.interface, - ) + # These bottom two are retryable + if len(response.choices) == 0 or response.choices[0] is None: + raise ValueError(f"API call returned an empty message: {response}") - if len(response.choices) == 0 or response.choices[0] is None: - empty_api_err_message = f"API call didn't return a message: {response}" - if fail_on_empty_response or empty_response_retry_limit == 0: - raise Exception(empty_api_err_message) + if response.choices[0].finish_reason not in ["stop", "function_call", "tool_calls"]: + if response.choices[0].finish_reason == "length": + # This is not retryable, hence RuntimeError v.s. ValueError + raise RuntimeError("Finish reason was length (maximum context length)") + else: + raise ValueError(f"Bad finish reason from API: {response.choices[0].finish_reason}") + + return response + + except ValueError as ve: + if attempt >= empty_response_retry_limit: + warnings.warn(f"Retry limit reached. Final error: {ve}") + break else: - # Decrement retry limit and try again - warnings.warn(empty_api_err_message) - return self._get_ai_reply( - message_sequence, function_call, first_message, stream, fail_on_empty_response, empty_response_retry_limit - 1 - ) + delay = min(backoff_factor * (2 ** (attempt - 1)), max_delay) + warnings.warn(f"Attempt {attempt} failed: {ve}. Retrying in {delay} seconds...") + time.sleep(delay) - # special case for 'length' - if response.choices[0].finish_reason == "length": - raise Exception("Finish reason was length (maximum context length)") + except Exception as e: + # For non-retryable errors, exit immediately + raise e - # catches for soft errors - if response.choices[0].finish_reason not in ["stop", "function_call", "tool_calls"]: - raise Exception(f"API call finish with bad finish reason: {response}") - - # unpack with response.choices[0].message.content - return response - except Exception as e: - raise e + raise Exception("Retries exhausted and no valid response received.") def _handle_ai_response( self,