feat: support filtering out approval messages for gemini (#2833)

This commit is contained in:
cthomas
2025-09-11 00:26:24 -07:00
committed by GitHub
parent 636fb52d87
commit 694be7b136
3 changed files with 21 additions and 4 deletions

View File

@@ -850,7 +850,7 @@ class LettaAgentV2(BaseAgentV2):
tool_call_messages = create_letta_messages_from_llm_response(
agent_id=agent_state.id,
model=agent_state.llm_config.model,
function_name="",
function_name=tool_call.function.name,
function_arguments={},
tool_execution_result=ToolExecutionResult(status="error"),
tool_call_id=tool_call_id,

View File

@@ -272,7 +272,7 @@ class GoogleVertexClient(LLMClientBase):
tool_names = []
contents = self.add_dummy_model_messages(
[m.to_google_ai_dict() for m in messages],
PydanticMessage.to_google_dicts_from_list(messages),
)
request_data = {

View File

@@ -1027,10 +1027,13 @@ class Message(BaseMessage):
result = [m for m in result if m is not None]
return result
def to_google_ai_dict(self, put_inner_thoughts_in_kwargs: bool = True) -> dict:
def to_google_dict(self, put_inner_thoughts_in_kwargs: bool = True) -> dict | None:
"""
Go from Message class to Google AI REST message object
"""
if self.role == "approval" and self.tool_calls is None:
return None
# type Content: https://ai.google.dev/api/rest/v1/Content / https://ai.google.dev/api/rest/v1beta/Content
# parts[]: Part
# role: str ('user' or 'model')
@@ -1076,7 +1079,7 @@ class Message(BaseMessage):
"parts": content_parts,
}
elif self.role == "assistant":
elif self.role == "assistant" or self.role == "approval":
assert self.tool_calls is not None or text_content is not None
google_ai_message = {
"role": "model", # NOTE: different
@@ -1164,6 +1167,20 @@ class Message(BaseMessage):
return google_ai_message
@staticmethod
def to_google_dicts_from_list(
messages: List[Message],
put_inner_thoughts_in_kwargs: bool = True,
):
result = [
m.to_google_dict(
put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs,
)
for m in messages
]
result = [m for m in result if m is not None]
return result
@staticmethod
def generate_otid_from_id(message_id: str, index: int) -> str:
"""