From 734680db811d583f955260281239c63106a1ee10 Mon Sep 17 00:00:00 2001 From: Andy Li <55300002+cliandy@users.noreply.github.com> Date: Mon, 23 Jun 2025 16:55:23 -0700 Subject: [PATCH] feat: timeout configuration for LLM clients + vertex (#2972) --- letta/agents/letta_agent.py | 2 +- letta/errors.py | 5 +++++ letta/llm_api/anthropic_client.py | 11 ++++++++++ letta/llm_api/google_ai_client.py | 10 +++++++--- letta/llm_api/google_vertex_client.py | 20 ++++++++++++++++--- letta/llm_api/openai_client.py | 13 ++++++++++++ .../services/helpers/agent_manager_helper.py | 8 ++++---- letta/settings.py | 14 ++++--------- tests/test_utils.py | 4 +++- 9 files changed, 65 insertions(+), 22 deletions(-) diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index f1a2b794..4fba688d 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -270,7 +270,7 @@ class LettaAgent(BaseAgent): if include_return_message_types is None or message.message_type in include_return_message_types: yield f"data: {message.model_dump_json()}\n\n" - MetricRegistry().step_execution_time_ms_histogram.record(step_start - get_utc_timestamp_ns(), get_ctx_attributes()) + MetricRegistry().step_execution_time_ms_histogram.record(get_utc_timestamp_ns() - step_start, get_ctx_attributes()) if not should_continue: break diff --git a/letta/errors.py b/letta/errors.py index 17427ea6..1e2f013f 100644 --- a/letta/errors.py +++ b/letta/errors.py @@ -17,6 +17,7 @@ class ErrorCode(Enum): INTERNAL_SERVER_ERROR = "INTERNAL_SERVER_ERROR" CONTEXT_WINDOW_EXCEEDED = "CONTEXT_WINDOW_EXCEEDED" RATE_LIMIT_EXCEEDED = "RATE_LIMIT_EXCEEDED" + TIMEOUT = "TIMEOUT" class LettaError(Exception): @@ -101,6 +102,10 @@ class LLMServerError(LLMError): while processing the request.""" +class LLMTimeoutError(LLMError): + """Error when LLM request times out""" + + class BedrockPermissionError(LettaError): """Exception raised for errors in the Bedrock permission process.""" diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py index e8bb91fa..9e940fe4 100644 --- a/letta/llm_api/anthropic_client.py +++ b/letta/llm_api/anthropic_client.py @@ -21,9 +21,11 @@ from letta.errors import ( LLMPermissionDeniedError, LLMRateLimitError, LLMServerError, + LLMTimeoutError, LLMUnprocessableEntityError, ) from letta.helpers.datetime_helpers import get_utc_time_int +from letta.helpers.decorators import deprecated from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_inner_thoughts_from_kwargs from letta.llm_api.llm_client_base import LLMClientBase from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION @@ -47,6 +49,7 @@ logger = get_logger(__name__) class AnthropicClient(LLMClientBase): @trace_method + @deprecated("Synchronous version of this is no longer valid. Will result in model_dump of coroutine") def request(self, request_data: dict, llm_config: LLMConfig) -> dict: client = self._get_anthropic_client(llm_config, async_client=False) response = client.beta.messages.create(**request_data) @@ -298,6 +301,14 @@ class AnthropicClient(LLMClientBase): @trace_method def handle_llm_error(self, e: Exception) -> Exception: + if isinstance(e, anthropic.APITimeoutError): + logger.warning(f"[Anthropic] Request timeout: {e}") + return LLMTimeoutError( + message=f"Request to Anthropic timed out: {str(e)}", + code=ErrorCode.TIMEOUT, + details={"cause": str(e.__cause__) if e.__cause__ else None}, + ) + if isinstance(e, anthropic.APIConnectionError): logger.warning(f"[Anthropic] API connection error: {e.__cause__}") return LLMConnectionError( diff --git a/letta/llm_api/google_ai_client.py b/letta/llm_api/google_ai_client.py index 47671398..a8fa03f4 100644 --- a/letta/llm_api/google_ai_client.py +++ b/letta/llm_api/google_ai_client.py @@ -2,20 +2,24 @@ from typing import List, Optional, Tuple import httpx from google import genai +from google.genai.types import HttpOptions from letta.errors import ErrorCode, LLMAuthenticationError, LLMError from letta.llm_api.google_constants import GOOGLE_MODEL_FOR_API_KEY_CHECK from letta.llm_api.google_vertex_client import GoogleVertexClient from letta.log import get_logger -from letta.settings import model_settings +from letta.settings import model_settings, settings logger = get_logger(__name__) class GoogleAIClient(GoogleVertexClient): - def _get_client(self): - return genai.Client(api_key=model_settings.gemini_api_key) + timeout_ms = int(settings.llm_request_timeout_seconds * 1000) + return genai.Client( + api_key=model_settings.gemini_api_key, + http_options=HttpOptions(timeout=timeout_ms), + ) def get_gemini_endpoint_and_headers( diff --git a/letta/llm_api/google_vertex_client.py b/letta/llm_api/google_vertex_client.py index c7c49ffa..a3fb9e04 100644 --- a/letta/llm_api/google_vertex_client.py +++ b/letta/llm_api/google_vertex_client.py @@ -3,7 +3,14 @@ import uuid from typing import List, Optional from google import genai -from google.genai.types import FunctionCallingConfig, FunctionCallingConfigMode, GenerateContentResponse, ThinkingConfig, ToolConfig +from google.genai.types import ( + FunctionCallingConfig, + FunctionCallingConfigMode, + GenerateContentResponse, + HttpOptions, + ThinkingConfig, + ToolConfig, +) from letta.constants import NON_USER_MSG_PREFIX from letta.helpers.datetime_helpers import get_utc_time_int @@ -26,11 +33,12 @@ logger = get_logger(__name__) class GoogleVertexClient(LLMClientBase): def _get_client(self): + timeout_ms = int(settings.llm_request_timeout_seconds * 1000) return genai.Client( vertexai=True, project=model_settings.google_cloud_project, location=model_settings.google_cloud_location, - http_options={"api_version": "v1"}, + http_options=HttpOptions(api_version="v1", timeout=timeout_ms), ) @trace_method @@ -59,7 +67,8 @@ class GoogleVertexClient(LLMClientBase): ) return response.model_dump() - def add_dummy_model_messages(self, messages: List[dict]) -> List[dict]: + @staticmethod + def add_dummy_model_messages(messages: List[dict]) -> List[dict]: """Google AI API requires all function call returns are immediately followed by a 'model' role message. In Letta, the 'model' will often call a function (e.g. send_message) that itself yields to the user, @@ -484,3 +493,8 @@ class GoogleVertexClient(LLMClientBase): "propertyOrdering": ["name", "args"], "required": ["name", "args"], } + + @trace_method + def handle_llm_error(self, e: Exception) -> Exception: + # Fallback to base implementation + return super().handle_llm_error(e) diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index 7d8e2678..849b2851 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -17,6 +17,7 @@ from letta.errors import ( LLMPermissionDeniedError, LLMRateLimitError, LLMServerError, + LLMTimeoutError, LLMUnprocessableEntityError, ) from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, unpack_all_inner_thoughts_from_kwargs @@ -317,6 +318,18 @@ class OpenAIClient(LLMClientBase): """ Maps OpenAI-specific errors to common LLMError types. """ + if isinstance(e, openai.APITimeoutError): + timeout_duration = getattr(e, "timeout", "unknown") + logger.warning(f"[OpenAI] Request timeout after {timeout_duration} seconds: {e}") + return LLMTimeoutError( + message=f"Request to OpenAI timed out: {str(e)}", + code=ErrorCode.TIMEOUT, + details={ + "timeout_duration": timeout_duration, + "cause": str(e.__cause__) if e.__cause__ else None, + }, + ) + if isinstance(e, openai.APIConnectionError): logger.warning(f"[OpenAI] API connection error: {e}") return LLMConnectionError( diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index 249e807e..c48dfd39 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -1,4 +1,4 @@ -import datetime +from datetime import datetime from typing import List, Literal, Optional import numpy as np @@ -178,7 +178,7 @@ def derive_system_message(agent_type: AgentType, enable_sleeptime: Optional[bool # TODO: This code is kind of wonky and deserves a rewrite def compile_memory_metadata_block( - memory_edit_timestamp: datetime.datetime, + memory_edit_timestamp: datetime, previous_message_count: int = 0, archival_memory_size: int = 0, ) -> str: @@ -223,7 +223,7 @@ def safe_format(template: str, variables: dict) -> str: def compile_system_message( system_prompt: str, in_context_memory: Memory, - in_context_memory_last_edit: datetime.datetime, # TODO move this inside of BaseMemory? + in_context_memory_last_edit: datetime, # TODO move this inside of BaseMemory? user_defined_variables: Optional[dict] = None, append_icm_if_missing: bool = True, template_format: Literal["f-string", "mustache", "jinja2"] = "f-string", @@ -292,7 +292,7 @@ def compile_system_message( def initialize_message_sequence( agent_state: AgentState, - memory_edit_timestamp: Optional[datetime.datetime] = None, + memory_edit_timestamp: Optional[datetime] = None, include_initial_boot_message: bool = True, previous_message_count: int = 0, archival_memory_size: int = 0, diff --git a/letta/settings.py b/letta/settings.py index eebc2919..5fbe0cc8 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -230,16 +230,6 @@ class Settings(BaseSettings): use_experimental: bool = False use_vertex_structured_outputs_experimental: bool = False - # LLM provider client settings - httpx_max_retries: int = 5 - httpx_timeout_connect: float = 10.0 - httpx_timeout_read: float = 60.0 - httpx_timeout_write: float = 30.0 - httpx_timeout_pool: float = 10.0 - httpx_max_connections: int = 500 - httpx_max_keepalive_connections: int = 500 - httpx_keepalive_expiry: float = 120.0 - # cron job parameters enable_batch_job_polling: bool = False poll_running_llm_batches_interval_seconds: int = 5 * 60 @@ -250,6 +240,10 @@ class Settings(BaseSettings): # for OCR mistral_api_key: Optional[str] = None + # LLM request timeout settings (model + embedding model) + llm_request_timeout_seconds: float = Field(default=60.0, ge=10.0, le=1800.0, description="Timeout for LLM requests in seconds") + llm_stream_timeout_seconds: float = Field(default=60.0, ge=10.0, le=1800.0, description="Timeout for LLM streaming requests in seconds") + @property def letta_pg_uri(self) -> str: if self.pg_uri: diff --git a/tests/test_utils.py b/tests/test_utils.py index 214dfcbb..8733039b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -249,7 +249,9 @@ def test_coerce_dict_args_unsupported_complex_annotation(): annotations = {"f": "CustomClass[int]"} function_args = {"f": "CustomClass(42)"} - with pytest.raises(ValueError, match="Failed to coerce argument 'f' to CustomClass\[int\]: Unsupported annotation: CustomClass\[int\]"): + with pytest.raises( + ValueError, match=r"Failed to coerce argument 'f' to CustomClass\[int\]: Unsupported annotation: CustomClass\[int\]" + ): coerce_dict_args_by_annotations(function_args, annotations)