diff --git a/letta/agent.py b/letta/agent.py index 79c5ac3c..32595b2a 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -384,6 +384,7 @@ class Agent(BaseAgent): delay = min(backoff_factor * (2 ** (attempt - 1)), max_delay) warnings.warn(f"Attempt {attempt} failed: {ve}. Retrying in {delay} seconds...") time.sleep(delay) + continue except Exception as e: # For non-retryable errors, exit immediately @@ -395,6 +396,7 @@ class Agent(BaseAgent): # trigger summarization log_telemetry(self.logger, "_get_ai_reply summarize_messages_inplace") self.summarize_messages_inplace() + # return the response return response diff --git a/letta/llm_api/google_ai_client.py b/letta/llm_api/google_ai_client.py index 430550be..c927ed3f 100644 --- a/letta/llm_api/google_ai_client.py +++ b/letta/llm_api/google_ai_client.py @@ -1,3 +1,4 @@ +import json import uuid from typing import List, Optional, Tuple @@ -11,6 +12,7 @@ from letta.llm_api.helpers import make_post_request from letta.llm_api.llm_client_base import LLMClientBase from letta.local_llm.json_parser import clean_json_string_extra_backslash from letta.local_llm.utils import count_tokens +from letta.log import get_logger from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.openai.chat_completion_request import Tool @@ -18,6 +20,8 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionResponse from letta.settings import model_settings from letta.utils import get_tool_call_id +logger = get_logger(__name__) + class GoogleAIClient(LLMClientBase): @@ -25,6 +29,8 @@ class GoogleAIClient(LLMClientBase): """ Performs underlying request to llm and returns raw response. 
""" + # print("[google_ai request]", json.dumps(request_data, indent=2)) + url, headers = get_gemini_endpoint_and_headers( base_url=str(self.llm_config.model_endpoint), model=self.llm_config.model, @@ -46,9 +52,10 @@ class GoogleAIClient(LLMClientBase): """ if tools: tools = [{"type": "function", "function": f} for f in tools] - tools = self.convert_tools_to_google_ai_format( - [Tool(**t) for t in tools], - ) + tool_objs = [Tool(**t) for t in tools] + tool_names = [t.function.name for t in tool_objs] + # Convert to the exact payload style Google expects + tools = self.convert_tools_to_google_ai_format(tool_objs) contents = self.add_dummy_model_messages( [m.to_google_ai_dict() for m in messages], ) @@ -67,6 +74,8 @@ class GoogleAIClient(LLMClientBase): function_calling_config=FunctionCallingConfig( # ANY mode forces the model to predict only function calls mode=FunctionCallingConfigMode.ANY, + # Provide the list of tools (though empty should also work, it seems not to) + allowed_function_names=tool_names, ) ) request_data["tool_config"] = tool_config.model_dump() @@ -101,6 +110,8 @@ class GoogleAIClient(LLMClientBase): } } """ + # print("[google_ai response]", json.dumps(response_data, indent=2)) + try: choices = [] index = 0 @@ -111,6 +122,17 @@ class GoogleAIClient(LLMClientBase): assert role == "model", f"Unknown role in response: {role}" parts = content["parts"] + + # NOTE: we aren't properly supported multi-parts here anyways (we're just appending choices), + # so let's disable it for now + + # NOTE(Apr 9, 2025): there's a very strange bug on 2.5 where the response has a part with broken text + # {'candidates': [{'content': {'parts': [{'functionCall': {'name': 'send_message', 'args': {'request_heartbeat': False, 'message': 'Hello! How can I make your day better?', 'inner_thoughts': 'User has initiated contact. 
Sending a greeting.'}}}], 'role': 'model'}, 'finishReason': 'STOP', 'avgLogprobs': -0.25891534213362066}], 'usageMetadata': {'promptTokenCount': 2493, 'candidatesTokenCount': 29, 'totalTokenCount': 2522, 'promptTokensDetails': [{'modality': 'TEXT', 'tokenCount': 2493}], 'candidatesTokensDetails': [{'modality': 'TEXT', 'tokenCount': 29}]}, 'modelVersion': 'gemini-1.5-pro-002'} + # To patch this, if we have multiple parts we can take the last one + if len(parts) > 1: + logger.warning(f"Unexpected multiple parts in response from Google AI: {parts}") + parts = [parts[-1]] + # TODO support parts / multimodal # TODO support parallel tool calling natively # TODO Alternative here is to throw away everything else except for the first part @@ -201,10 +223,22 @@ class GoogleAIClient(LLMClientBase): # "totalTokenCount": 36 # } if "usageMetadata" in response_data: + usage_data = response_data["usageMetadata"] + if "promptTokenCount" not in usage_data: + raise ValueError(f"promptTokenCount not found in usageMetadata:\n{json.dumps(usage_data, indent=2)}") + if "totalTokenCount" not in usage_data: + raise ValueError(f"totalTokenCount not found in usageMetadata:\n{json.dumps(usage_data, indent=2)}") + if "candidatesTokenCount" not in usage_data: + raise ValueError(f"candidatesTokenCount not found in usageMetadata:\n{json.dumps(usage_data, indent=2)}") + + prompt_tokens = usage_data["promptTokenCount"] + completion_tokens = usage_data["candidatesTokenCount"] + total_tokens = usage_data["totalTokenCount"] + usage = UsageStatistics( - prompt_tokens=response_data["usageMetadata"]["promptTokenCount"], - completion_tokens=response_data["usageMetadata"]["candidatesTokenCount"], - total_tokens=response_data["usageMetadata"]["totalTokenCount"], + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, ) else: # Count it ourselves diff --git a/letta/llm_api/google_vertex_client.py b/letta/llm_api/google_vertex_client.py index 937dbe22..3bd4eb95 100644 
--- a/letta/llm_api/google_vertex_client.py +++ b/letta/llm_api/google_vertex_client.py @@ -49,10 +49,13 @@ class GoogleVertexClient(GoogleAIClient): request_data["config"] = request_data.pop("generation_config") request_data["config"]["tools"] = request_data.pop("tools") + tool_names = [t["name"] for t in tools] tool_config = ToolConfig( function_calling_config=FunctionCallingConfig( # ANY mode forces the model to predict only function calls mode=FunctionCallingConfigMode.ANY, + # Provide the explicit list of tool names (an empty list should also work, but in practice it does not) + allowed_function_names=tool_names, ) ) request_data["config"]["tool_config"] = tool_config.model_dump() @@ -88,6 +91,8 @@ class GoogleVertexClient(GoogleAIClient): } } """ + # print(response_data) + response = GenerateContentResponse(**response_data) try: choices = [] @@ -99,6 +104,17 @@ class GoogleVertexClient(GoogleAIClient): assert role == "model", f"Unknown role in response: {role}" parts = content.parts + + # NOTE: we don't properly support multi-part responses here anyway (we're just appending choices), + # so let's disable it for now + + # NOTE(Apr 9, 2025): there's a very strange bug on 2.5 where the response has a part with broken text + # {'candidates': [{'content': {'parts': [{'functionCall': {'name': 'send_message', 'args': {'request_heartbeat': False, 'message': 'Hello! How can I make your day better?', 'inner_thoughts': 'User has initiated contact. 
Sending a greeting.'}}}], 'role': 'model'}, 'finishReason': 'STOP', 'avgLogprobs': -0.25891534213362066}], 'usageMetadata': {'promptTokenCount': 2493, 'candidatesTokenCount': 29, 'totalTokenCount': 2522, 'promptTokensDetails': [{'modality': 'TEXT', 'tokenCount': 2493}], 'candidatesTokensDetails': [{'modality': 'TEXT', 'tokenCount': 29}]}, 'modelVersion': 'gemini-1.5-pro-002'} + # To patch this, if we have multiple parts we can take the last one + if len(parts) > 1: + logger.warning(f"Unexpected multiple parts in response from Google AI: {parts}") + parts = [parts[-1]] + # TODO support parts / multimodal # TODO support parallel tool calling natively # TODO Alternative here is to throw away everything else except for the first part diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py index 3ef30188..37da42a9 100644 --- a/letta/server/rest_api/interface.py +++ b/letta/server/rest_api/interface.py @@ -29,6 +29,7 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionChunkRes from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser from letta.streaming_interface import AgentChunkStreamingInterface from letta.streaming_utils import FunctionArgumentsStreamHandler, JSONInnerThoughtsExtractor +from letta.utils import parse_json # TODO strip from code / deprecate @@ -408,7 +409,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # if self.expect_reasoning_content_buffer is not None: # try: # # NOTE: this is hardcoded for our DeepSeek API integration - # json_reasoning_content = json.loads(self.expect_reasoning_content_buffer) + # json_reasoning_content = parse_json(self.expect_reasoning_content_buffer) # if "name" in json_reasoning_content: # self._push_to_buffer( @@ -528,7 +529,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): try: # NOTE: this is hardcoded for our DeepSeek API integration - json_reasoning_content = json.loads(self.expect_reasoning_content_buffer) + 
json_reasoning_content = parse_json(self.expect_reasoning_content_buffer) print(f"json_reasoning_content: {json_reasoning_content}") processed_chunk = ToolCallMessage( @@ -1188,7 +1189,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # "date": "2024-06-22T23:04:32.141923+00:00" # } try: - func_args = json.loads(function_call.function.arguments) + func_args = parse_json(function_call.function.arguments) except: func_args = function_call.function.arguments # processed_chunk = { @@ -1224,7 +1225,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): else: try: - func_args = json.loads(function_call.function.arguments) + func_args = parse_json(function_call.function.arguments) except: warnings.warn(f"Failed to parse function arguments: {function_call.function.arguments}") func_args = {} diff --git a/letta/utils.py b/letta/utils.py index 57ff86d0..fbb926b8 100644 --- a/letta/utils.py +++ b/letta/utils.py @@ -828,7 +828,7 @@ def parse_json(string) -> dict: raise ValueError(f"JSON from string input ({string}) is not a dictionary (type {type(result)}): {result}") return result except Exception as e: - print(f"Error parsing json with json package: {e}") + print(f"Error parsing json with json package, falling back to demjson: {e}") try: result = demjson.decode(string) @@ -836,7 +836,7 @@ def parse_json(string) -> dict: raise ValueError(f"JSON from string input ({string}) is not a dictionary (type {type(result)}): {result}") return result except demjson.JSONDecodeError as e: - print(f"Error parsing json with demjson package: {e}") + print(f"Error parsing json with demjson package (fatal): {e}") raise e