fix: patch gemini 2.5 pro (#1643)
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
@@ -384,6 +384,7 @@ class Agent(BaseAgent):
                     delay = min(backoff_factor * (2 ** (attempt - 1)), max_delay)
                     warnings.warn(f"Attempt {attempt} failed: {ve}. Retrying in {delay} seconds...")
                     time.sleep(delay)
+                    continue

             except Exception as e:
                 # For non-retryable errors, exit immediately
@@ -395,6 +396,7 @@ class Agent(BaseAgent):
             # trigger summarization
             log_telemetry(self.logger, "_get_ai_reply summarize_messages_inplace")
             self.summarize_messages_inplace()
+
             # return the response
             return response

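The agent-side fix is the explicit continue: after a retryable ValueError from the model endpoint, the loop backs off exponentially (capped at max_delay) before retrying instead of falling through. A minimal standalone sketch of the same pattern, with hypothetical names and defaults (retry_limit, backoff_factor, and max_delay are illustrative, not letta's exact configuration):

    import time
    import warnings


    def call_with_backoff(fn, retry_limit: int = 3, backoff_factor: float = 2.0, max_delay: float = 10.0):
        """Retry fn() on ValueError with capped exponential backoff, mirroring the patched loop."""
        for attempt in range(1, retry_limit + 1):
            try:
                return fn()
            except ValueError as ve:
                if attempt >= retry_limit:
                    raise Exception(f"Retries exhausted: {ve}")
                # delay doubles each attempt (factor * 2^(attempt-1)) but never exceeds max_delay
                delay = min(backoff_factor * (2 ** (attempt - 1)), max_delay)
                warnings.warn(f"Attempt {attempt} failed: {ve}. Retrying in {delay} seconds...")
                time.sleep(delay)
                continue  # explicit continue, as added in the patch above
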
@@ -1,3 +1,4 @@
+import json
 import uuid
 from typing import List, Optional, Tuple

@@ -11,6 +12,7 @@ from letta.llm_api.helpers import make_post_request
 from letta.llm_api.llm_client_base import LLMClientBase
 from letta.local_llm.json_parser import clean_json_string_extra_backslash
 from letta.local_llm.utils import count_tokens
+from letta.log import get_logger
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message as PydanticMessage
 from letta.schemas.openai.chat_completion_request import Tool
@@ -18,6 +20,8 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
 from letta.settings import model_settings
 from letta.utils import get_tool_call_id

+logger = get_logger(__name__)
+

 class GoogleAIClient(LLMClientBase):

@@ -25,6 +29,8 @@ class GoogleAIClient(LLMClientBase):
         """
         Performs underlying request to llm and returns raw response.
         """
+        # print("[google_ai request]", json.dumps(request_data, indent=2))
+
         url, headers = get_gemini_endpoint_and_headers(
             base_url=str(self.llm_config.model_endpoint),
             model=self.llm_config.model,
@@ -46,9 +52,10 @@ class GoogleAIClient(LLMClientBase):
         """
         if tools:
-            tools = [{"type": "function", "function": f} for f in tools]
-            tools = self.convert_tools_to_google_ai_format(
-                [Tool(**t) for t in tools],
-            )
+            tool_objs = [Tool(**t) for t in tools]
+            tool_names = [t.function.name for t in tool_objs]
+            # Convert to the exact payload style Google expects
+            tools = self.convert_tools_to_google_ai_format(tool_objs)
+
         contents = self.add_dummy_model_messages(
             [m.to_google_ai_dict() for m in messages],
         )
@@ -67,6 +74,8 @@ class GoogleAIClient(LLMClientBase):
                 function_calling_config=FunctionCallingConfig(
                     # ANY mode forces the model to predict only function calls
                     mode=FunctionCallingConfigMode.ANY,
+                    # Provide the list of tools (an empty list should also work, but it seems not to)
+                    allowed_function_names=tool_names,
                 )
             )
             request_data["tool_config"] = tool_config.model_dump()
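On the request side, the incoming OpenAI-style tool dicts are parsed once into Tool objects, their names collected, and those names pinned via allowed_function_names so that ANY mode can only emit functions that were actually sent. A rough sketch of the resulting tool_config payload, assuming the google-genai SDK types the diff already uses (request_data and tool_names are placeholders):

    from google.genai.types import FunctionCallingConfig, FunctionCallingConfigMode, ToolConfig

    request_data = {}  # placeholder for the request body being assembled
    tool_names = ["send_message"]  # illustrative; gathered from the request's tools

    tool_config = ToolConfig(
        function_calling_config=FunctionCallingConfig(
            # ANY mode forces the model to predict only function calls
            mode=FunctionCallingConfigMode.ANY,
            # ...and restricts those calls to the tools we actually sent
            allowed_function_names=tool_names,
        )
    )
    # Roughly serializes to:
    # {"function_calling_config": {"mode": "ANY", "allowed_function_names": ["send_message"]}}
    request_data["tool_config"] = tool_config.model_dump()
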
@@ -101,6 +110,8 @@ class GoogleAIClient(LLMClientBase):
             }
         }
         """
+        # print("[google_ai response]", json.dumps(response_data, indent=2))
+
         try:
             choices = []
             index = 0
@@ -111,6 +122,17 @@ class GoogleAIClient(LLMClientBase):
                 assert role == "model", f"Unknown role in response: {role}"

                 parts = content["parts"]
+
+                # NOTE: we aren't properly supporting multi-part responses here anyway (we're just appending choices),
+                # so let's disable it for now
+
+                # NOTE (Apr 9, 2025): there's a very strange bug on 2.5 where the response has a part with broken text
+                # {'candidates': [{'content': {'parts': [{'functionCall': {'name': 'send_message', 'args': {'request_heartbeat': False, 'message': 'Hello! How can I make your day better?', 'inner_thoughts': 'User has initiated contact. Sending a greeting.'}}}], 'role': 'model'}, 'finishReason': 'STOP', 'avgLogprobs': -0.25891534213362066}], 'usageMetadata': {'promptTokenCount': 2493, 'candidatesTokenCount': 29, 'totalTokenCount': 2522, 'promptTokensDetails': [{'modality': 'TEXT', 'tokenCount': 2493}], 'candidatesTokensDetails': [{'modality': 'TEXT', 'tokenCount': 29}]}, 'modelVersion': 'gemini-1.5-pro-002'}
+                # To patch this, if we have multiple parts we can take the last one
+                if len(parts) > 1:
+                    logger.warning(f"Unexpected multiple parts in response from Google AI: {parts}")
+                    parts = [parts[-1]]
+
                 # TODO support parts / multimodal
                 # TODO support parallel tool calling natively
                 # TODO Alternative here is to throw away everything else except for the first part
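The Gemini 2.5 patch itself is this parts check: 2.5 occasionally returns a candidate whose content carries an extra broken part alongside the real functionCall part, and since multi-part choices aren't supported here anyway, only the last part is kept. The same logic as a standalone sketch (normalize_parts is a hypothetical helper; letta applies this inline, as shown above):

    import logging

    logger = logging.getLogger(__name__)


    def normalize_parts(parts: list) -> list:
        """Collapse a multi-part Gemini candidate to its final part.

        Works around a Gemini 2.5 quirk where a broken text part can appear
        alongside the real functionCall part; the last part is kept.
        """
        if len(parts) > 1:
            logger.warning(f"Unexpected multiple parts in response from Google AI: {parts}")
            return [parts[-1]]
        return parts
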
@@ -201,10 +223,22 @@ class GoogleAIClient(LLMClientBase):
         # "totalTokenCount": 36
         # }
         if "usageMetadata" in response_data:
+            usage_data = response_data["usageMetadata"]
+            if "promptTokenCount" not in usage_data:
+                raise ValueError(f"promptTokenCount not found in usageMetadata:\n{json.dumps(usage_data, indent=2)}")
+            if "totalTokenCount" not in usage_data:
+                raise ValueError(f"totalTokenCount not found in usageMetadata:\n{json.dumps(usage_data, indent=2)}")
+            if "candidatesTokenCount" not in usage_data:
+                raise ValueError(f"candidatesTokenCount not found in usageMetadata:\n{json.dumps(usage_data, indent=2)}")
+
+            prompt_tokens = usage_data["promptTokenCount"]
+            completion_tokens = usage_data["candidatesTokenCount"]
+            total_tokens = usage_data["totalTokenCount"]
+
             usage = UsageStatistics(
-                prompt_tokens=response_data["usageMetadata"]["promptTokenCount"],
-                completion_tokens=response_data["usageMetadata"]["candidatesTokenCount"],
-                total_tokens=response_data["usageMetadata"]["totalTokenCount"],
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=total_tokens,
             )
         else:
             # Count it ourselves
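Usage parsing now validates usageMetadata field by field, so a malformed Gemini response fails with a readable ValueError instead of a bare KeyError. A condensed sketch of the same checks (extract_usage is a hypothetical helper; letta feeds the result into its UsageStatistics schema, and falls back to its local count_tokens helper when the metadata is absent):

    import json
    from typing import Optional, Tuple


    def extract_usage(response_data: dict) -> Optional[Tuple[int, int, int]]:
        """Return (prompt, completion, total) token counts, or None if absent.

        A None result means the caller should count tokens itself.
        """
        usage_data = response_data.get("usageMetadata")
        if usage_data is None:
            return None
        for key in ("promptTokenCount", "candidatesTokenCount", "totalTokenCount"):
            if key not in usage_data:
                raise ValueError(f"{key} not found in usageMetadata:\n{json.dumps(usage_data, indent=2)}")
        return (
            usage_data["promptTokenCount"],
            usage_data["candidatesTokenCount"],
            usage_data["totalTokenCount"],
        )
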
@@ -49,10 +49,13 @@ class GoogleVertexClient(GoogleAIClient):
         request_data["config"] = request_data.pop("generation_config")
         request_data["config"]["tools"] = request_data.pop("tools")

+        tool_names = [t["name"] for t in tools]
         tool_config = ToolConfig(
             function_calling_config=FunctionCallingConfig(
                 # ANY mode forces the model to predict only function calls
                 mode=FunctionCallingConfigMode.ANY,
+                # Provide the list of tools (an empty list should also work, but it seems not to)
+                allowed_function_names=tool_names,
             )
         )
         request_data["config"]["tool_config"] = tool_config.model_dump()
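GoogleVertexClient gets the same allowed_function_names pin, threaded through the google-genai request shape where generation settings live under config. Approximately, the request body after this hunk looks like the following (a sketch of the shape only; keys follow the diff above, values and the contents field are illustrative):

    # Approximate shape of request_data after the Vertex patch:
    request_data = {
        "contents": [...],  # conversation turns (assumed field, not shown in the hunk)
        "config": {
            "tools": [...],  # function declarations, moved under config
            "tool_config": {
                "function_calling_config": {
                    "mode": "ANY",  # force a function call
                    "allowed_function_names": ["send_message"],  # illustrative
                }
            },
        },
    }
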
@@ -88,6 +91,8 @@ class GoogleVertexClient(GoogleAIClient):
             }
         }
         """
+        # print(response_data)
+
         response = GenerateContentResponse(**response_data)
         try:
             choices = []
@@ -99,6 +104,17 @@ class GoogleVertexClient(GoogleAIClient):
                 assert role == "model", f"Unknown role in response: {role}"

                 parts = content.parts
+
+                # NOTE: we aren't properly supporting multi-part responses here anyway (we're just appending choices),
+                # so let's disable it for now
+
+                # NOTE (Apr 9, 2025): there's a very strange bug on 2.5 where the response has a part with broken text
+                # {'candidates': [{'content': {'parts': [{'functionCall': {'name': 'send_message', 'args': {'request_heartbeat': False, 'message': 'Hello! How can I make your day better?', 'inner_thoughts': 'User has initiated contact. Sending a greeting.'}}}], 'role': 'model'}, 'finishReason': 'STOP', 'avgLogprobs': -0.25891534213362066}], 'usageMetadata': {'promptTokenCount': 2493, 'candidatesTokenCount': 29, 'totalTokenCount': 2522, 'promptTokensDetails': [{'modality': 'TEXT', 'tokenCount': 2493}], 'candidatesTokensDetails': [{'modality': 'TEXT', 'tokenCount': 29}]}, 'modelVersion': 'gemini-1.5-pro-002'}
+                # To patch this, if we have multiple parts we can take the last one
+                if len(parts) > 1:
+                    logger.warning(f"Unexpected multiple parts in response from Google AI: {parts}")
+                    parts = [parts[-1]]
+
                 # TODO support parts / multimodal
                 # TODO support parallel tool calling natively
                 # TODO Alternative here is to throw away everything else except for the first part
@@ -29,6 +29,7 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
 from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
 from letta.streaming_interface import AgentChunkStreamingInterface
 from letta.streaming_utils import FunctionArgumentsStreamHandler, JSONInnerThoughtsExtractor
+from letta.utils import parse_json


 # TODO strip from code / deprecate
@@ -408,7 +409,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
        # if self.expect_reasoning_content_buffer is not None:
        #     try:
        #         # NOTE: this is hardcoded for our DeepSeek API integration
-       #         json_reasoning_content = json.loads(self.expect_reasoning_content_buffer)
+       #         json_reasoning_content = parse_json(self.expect_reasoning_content_buffer)

        #         if "name" in json_reasoning_content:
        #             self._push_to_buffer(
@@ -528,7 +529,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):

             try:
                 # NOTE: this is hardcoded for our DeepSeek API integration
-                json_reasoning_content = json.loads(self.expect_reasoning_content_buffer)
+                json_reasoning_content = parse_json(self.expect_reasoning_content_buffer)
                 print(f"json_reasoning_content: {json_reasoning_content}")

                 processed_chunk = ToolCallMessage(
@@ -1188,7 +1189,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
             # "date": "2024-06-22T23:04:32.141923+00:00"
             # }
             try:
-                func_args = json.loads(function_call.function.arguments)
+                func_args = parse_json(function_call.function.arguments)
             except:
                 func_args = function_call.function.arguments
             # processed_chunk = {
@@ -1224,7 +1225,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
             else:

                 try:
-                    func_args = json.loads(function_call.function.arguments)
+                    func_args = parse_json(function_call.function.arguments)
                 except:
                     warnings.warn(f"Failed to parse function arguments: {function_call.function.arguments}")
                     func_args = {}
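Throughout the streaming interface, json.loads on model-produced text becomes parse_json, letta's lenient parser (strict json first, then a demjson fallback; see letta/utils.py below). The motivation is that Gemini 2.5 can emit tool arguments with escapes that strict JSON rejects. A small illustration with a contrived payload:

    import json

    # Strict JSON rejects the invalid \! escape sequence
    malformed = '{"message": "Hello\\! How can I help?"}'
    try:
        json.loads(malformed)
    except json.JSONDecodeError as e:
        print(f"json.loads failed: {e}")

    # parse_json retries with the more permissive demjson decoder before giving up:
    #   func_args = parse_json(function_call.function.arguments)
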
@@ -828,7 +828,7 @@ def parse_json(string) -> dict:
             raise ValueError(f"JSON from string input ({string}) is not a dictionary (type {type(result)}): {result}")
         return result
     except Exception as e:
-        print(f"Error parsing json with json package: {e}")
+        print(f"Error parsing json with json package, falling back to demjson: {e}")

     try:
         result = demjson.decode(string)
@@ -836,7 +836,7 @@ def parse_json(string) -> dict:
             raise ValueError(f"JSON from string input ({string}) is not a dictionary (type {type(result)}): {result}")
         return result
     except demjson.JSONDecodeError as e:
-        print(f"Error parsing json with demjson package: {e}")
+        print(f"Error parsing json with demjson package (fatal): {e}")
         raise e
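With the reworded messages, each failure now says what happens next: the json failure announces the demjson fallback, and the demjson failure is marked fatal before re-raising. The full fallback shape, as a minimal sketch (assuming the demjson3 package provides the demjson API letta imports):

    import json

    import demjson3 as demjson  # assumed: the demjson3 package, imported under the classic name


    def parse_json_fallback(string: str) -> dict:
        """Strict json first; on failure, retry with the lenient demjson decoder."""
        try:
            result = json.loads(string)
            if not isinstance(result, dict):
                raise ValueError(f"JSON from string input ({string}) is not a dictionary (type {type(result)}): {result}")
            return result
        except Exception as e:
            print(f"Error parsing json with json package, falling back to demjson: {e}")

        try:
            result = demjson.decode(string)
            if not isinstance(result, dict):
                raise ValueError(f"JSON from string input ({string}) is not a dictionary (type {type(result)}): {result}")
            return result
        except demjson.JSONDecodeError as e:
            print(f"Error parsing json with demjson package (fatal): {e}")
            raise e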