fix: patch gemini 2.5 pro (#1643)

Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
This commit is contained in:
Charles Packer
2025-04-09 18:35:53 -07:00
committed by GitHub
parent 8f5a43b886
commit ee8e095b69
5 changed files with 65 additions and 12 deletions

View File

@@ -384,6 +384,7 @@ class Agent(BaseAgent):
delay = min(backoff_factor * (2 ** (attempt - 1)), max_delay)
warnings.warn(f"Attempt {attempt} failed: {ve}. Retrying in {delay} seconds...")
time.sleep(delay)
continue
except Exception as e:
# For non-retryable errors, exit immediately
@@ -395,6 +396,7 @@ class Agent(BaseAgent):
# trigger summarization
log_telemetry(self.logger, "_get_ai_reply summarize_messages_inplace")
self.summarize_messages_inplace()
# return the response
return response

View File

@@ -1,3 +1,4 @@
import json
import uuid
from typing import List, Optional, Tuple
@@ -11,6 +12,7 @@ from letta.llm_api.helpers import make_post_request
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.json_parser import clean_json_string_extra_backslash
from letta.local_llm.utils import count_tokens
from letta.log import get_logger
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import Tool
@@ -18,6 +20,8 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.settings import model_settings
from letta.utils import get_tool_call_id
logger = get_logger(__name__)
class GoogleAIClient(LLMClientBase):
@@ -25,6 +29,8 @@ class GoogleAIClient(LLMClientBase):
"""
Performs underlying request to llm and returns raw response.
"""
# print("[google_ai request]", json.dumps(request_data, indent=2))
url, headers = get_gemini_endpoint_and_headers(
base_url=str(self.llm_config.model_endpoint),
model=self.llm_config.model,
@@ -46,9 +52,10 @@ class GoogleAIClient(LLMClientBase):
"""
if tools:
tools = [{"type": "function", "function": f} for f in tools]
tools = self.convert_tools_to_google_ai_format(
[Tool(**t) for t in tools],
)
tool_objs = [Tool(**t) for t in tools]
tool_names = [t.function.name for t in tool_objs]
# Convert to the exact payload style Google expects
tools = self.convert_tools_to_google_ai_format(tool_objs)
contents = self.add_dummy_model_messages(
[m.to_google_ai_dict() for m in messages],
)
@@ -67,6 +74,8 @@ class GoogleAIClient(LLMClientBase):
function_calling_config=FunctionCallingConfig(
# ANY mode forces the model to predict only function calls
mode=FunctionCallingConfigMode.ANY,
# Provide the list of tools (though empty should also work, it seems not to)
allowed_function_names=tool_names,
)
)
request_data["tool_config"] = tool_config.model_dump()
@@ -101,6 +110,8 @@ class GoogleAIClient(LLMClientBase):
}
}
"""
# print("[google_ai response]", json.dumps(response_data, indent=2))
try:
choices = []
index = 0
@@ -111,6 +122,17 @@ class GoogleAIClient(LLMClientBase):
assert role == "model", f"Unknown role in response: {role}"
parts = content["parts"]
# NOTE: we aren't properly supported multi-parts here anyways (we're just appending choices),
# so let's disable it for now
# NOTE(Apr 9, 2025): there's a very strange bug on 2.5 where the response has a part with broken text
# {'candidates': [{'content': {'parts': [{'functionCall': {'name': 'send_message', 'args': {'request_heartbeat': False, 'message': 'Hello! How can I make your day better?', 'inner_thoughts': 'User has initiated contact. Sending a greeting.'}}}], 'role': 'model'}, 'finishReason': 'STOP', 'avgLogprobs': -0.25891534213362066}], 'usageMetadata': {'promptTokenCount': 2493, 'candidatesTokenCount': 29, 'totalTokenCount': 2522, 'promptTokensDetails': [{'modality': 'TEXT', 'tokenCount': 2493}], 'candidatesTokensDetails': [{'modality': 'TEXT', 'tokenCount': 29}]}, 'modelVersion': 'gemini-1.5-pro-002'}
# To patch this, if we have multiple parts we can take the last one
if len(parts) > 1:
logger.warning(f"Unexpected multiple parts in response from Google AI: {parts}")
parts = [parts[-1]]
# TODO support parts / multimodal
# TODO support parallel tool calling natively
# TODO Alternative here is to throw away everything else except for the first part
@@ -201,10 +223,22 @@ class GoogleAIClient(LLMClientBase):
# "totalTokenCount": 36
# }
if "usageMetadata" in response_data:
usage_data = response_data["usageMetadata"]
if "promptTokenCount" not in usage_data:
raise ValueError(f"promptTokenCount not found in usageMetadata:\n{json.dumps(usage_data, indent=2)}")
if "totalTokenCount" not in usage_data:
raise ValueError(f"totalTokenCount not found in usageMetadata:\n{json.dumps(usage_data, indent=2)}")
if "candidatesTokenCount" not in usage_data:
raise ValueError(f"candidatesTokenCount not found in usageMetadata:\n{json.dumps(usage_data, indent=2)}")
prompt_tokens = usage_data["promptTokenCount"]
completion_tokens = usage_data["candidatesTokenCount"]
total_tokens = usage_data["totalTokenCount"]
usage = UsageStatistics(
prompt_tokens=response_data["usageMetadata"]["promptTokenCount"],
completion_tokens=response_data["usageMetadata"]["candidatesTokenCount"],
total_tokens=response_data["usageMetadata"]["totalTokenCount"],
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)
else:
# Count it ourselves

View File

@@ -49,10 +49,13 @@ class GoogleVertexClient(GoogleAIClient):
request_data["config"] = request_data.pop("generation_config")
request_data["config"]["tools"] = request_data.pop("tools")
tool_names = [t["name"] for t in tools]
tool_config = ToolConfig(
function_calling_config=FunctionCallingConfig(
# ANY mode forces the model to predict only function calls
mode=FunctionCallingConfigMode.ANY,
# Provide the list of tools (though empty should also work, it seems not to)
allowed_function_names=tool_names,
)
)
request_data["config"]["tool_config"] = tool_config.model_dump()
@@ -88,6 +91,8 @@ class GoogleVertexClient(GoogleAIClient):
}
}
"""
# print(response_data)
response = GenerateContentResponse(**response_data)
try:
choices = []
@@ -99,6 +104,17 @@ class GoogleVertexClient(GoogleAIClient):
assert role == "model", f"Unknown role in response: {role}"
parts = content.parts
# NOTE: we aren't properly supported multi-parts here anyways (we're just appending choices),
# so let's disable it for now
# NOTE(Apr 9, 2025): there's a very strange bug on 2.5 where the response has a part with broken text
# {'candidates': [{'content': {'parts': [{'functionCall': {'name': 'send_message', 'args': {'request_heartbeat': False, 'message': 'Hello! How can I make your day better?', 'inner_thoughts': 'User has initiated contact. Sending a greeting.'}}}], 'role': 'model'}, 'finishReason': 'STOP', 'avgLogprobs': -0.25891534213362066}], 'usageMetadata': {'promptTokenCount': 2493, 'candidatesTokenCount': 29, 'totalTokenCount': 2522, 'promptTokensDetails': [{'modality': 'TEXT', 'tokenCount': 2493}], 'candidatesTokensDetails': [{'modality': 'TEXT', 'tokenCount': 29}]}, 'modelVersion': 'gemini-1.5-pro-002'}
# To patch this, if we have multiple parts we can take the last one
if len(parts) > 1:
logger.warning(f"Unexpected multiple parts in response from Google AI: {parts}")
parts = [parts[-1]]
# TODO support parts / multimodal
# TODO support parallel tool calling natively
# TODO Alternative here is to throw away everything else except for the first part

View File

@@ -29,6 +29,7 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionChunkRes
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
from letta.streaming_interface import AgentChunkStreamingInterface
from letta.streaming_utils import FunctionArgumentsStreamHandler, JSONInnerThoughtsExtractor
from letta.utils import parse_json
# TODO strip from code / deprecate
@@ -408,7 +409,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# if self.expect_reasoning_content_buffer is not None:
# try:
# # NOTE: this is hardcoded for our DeepSeek API integration
# json_reasoning_content = json.loads(self.expect_reasoning_content_buffer)
# json_reasoning_content = parse_json(self.expect_reasoning_content_buffer)
# if "name" in json_reasoning_content:
# self._push_to_buffer(
@@ -528,7 +529,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
try:
# NOTE: this is hardcoded for our DeepSeek API integration
json_reasoning_content = json.loads(self.expect_reasoning_content_buffer)
json_reasoning_content = parse_json(self.expect_reasoning_content_buffer)
print(f"json_reasoning_content: {json_reasoning_content}")
processed_chunk = ToolCallMessage(
@@ -1188,7 +1189,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# "date": "2024-06-22T23:04:32.141923+00:00"
# }
try:
func_args = json.loads(function_call.function.arguments)
func_args = parse_json(function_call.function.arguments)
except:
func_args = function_call.function.arguments
# processed_chunk = {
@@ -1224,7 +1225,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
else:
try:
func_args = json.loads(function_call.function.arguments)
func_args = parse_json(function_call.function.arguments)
except:
warnings.warn(f"Failed to parse function arguments: {function_call.function.arguments}")
func_args = {}

View File

@@ -828,7 +828,7 @@ def parse_json(string) -> dict:
raise ValueError(f"JSON from string input ({string}) is not a dictionary (type {type(result)}): {result}")
return result
except Exception as e:
print(f"Error parsing json with json package: {e}")
print(f"Error parsing json with json package, falling back to demjson: {e}")
try:
result = demjson.decode(string)
@@ -836,7 +836,7 @@ def parse_json(string) -> dict:
raise ValueError(f"JSON from string input ({string}) is not a dictionary (type {type(result)}): {result}")
return result
except demjson.JSONDecodeError as e:
print(f"Error parsing json with demjson package: {e}")
print(f"Error parsing json with demjson package (fatal): {e}")
raise e