feat: add gemini streaming to new agent loop (#5109)
* feat: add gemini streaming to new agent loop * add google as required dependency * support storing all content parts * remove extra google references
This commit is contained in:
2
.github/workflows/core-unit-sqlite-test.yaml
vendored
2
.github/workflows/core-unit-sqlite-test.yaml
vendored
@@ -25,7 +25,7 @@ jobs:
|
||||
apps/core/**
|
||||
.github/workflows/reusable-test-workflow.yml
|
||||
.github/workflows/core-unit-sqlite-test.yml
|
||||
install-args: '--extra postgres --extra external-tools --extra dev --extra cloud-tool-sandbox --extra google --extra sqlite'
|
||||
install-args: '--extra postgres --extra external-tools --extra dev --extra cloud-tool-sandbox --extra sqlite'
|
||||
timeout-minutes: 15
|
||||
ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
|
||||
|
||||
2
.github/workflows/core-unit-test.yml
vendored
2
.github/workflows/core-unit-test.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
**
|
||||
.github/workflows/reusable-test-workflow.yml
|
||||
.github/workflows/core-unit-test.yml
|
||||
install-args: '--extra postgres --extra external-tools --extra dev --extra cloud-tool-sandbox --extra google'
|
||||
install-args: '--extra postgres --extra external-tools --extra dev --extra cloud-tool-sandbox'
|
||||
timeout-minutes: 15
|
||||
ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
matrix-strategy: |
|
||||
|
||||
2
.github/workflows/model-sweep.yaml
vendored
2
.github/workflows/model-sweep.yaml
vendored
@@ -61,7 +61,7 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: uv sync --extra dev --extra postgres --extra external-tools --extra cloud-tool-sandbox --extra google
|
||||
run: uv sync --extra dev --extra postgres --extra external-tools --extra cloud-tool-sandbox
|
||||
- name: Migrate database
|
||||
env:
|
||||
LETTA_PG_PORT: 5432
|
||||
|
||||
@@ -25,7 +25,7 @@ jobs:
|
||||
**
|
||||
.github/workflows/reusable-test-workflow.yml
|
||||
.github/workflows/send-message-integration-tests.yml
|
||||
install-args: '--extra dev --extra postgres --extra external-tools --extra cloud-tool-sandbox --extra google --extra redis'
|
||||
install-args: '--extra dev --extra postgres --extra external-tools --extra cloud-tool-sandbox --extra redis'
|
||||
timeout-minutes: 15
|
||||
runner: '["self-hosted", "medium"]'
|
||||
ref: ${{ github.event.pull_request.head.sha || github.sha }}
|
||||
|
||||
2
.github/workflows/test-lmstudio.yml
vendored
2
.github/workflows/test-lmstudio.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
with:
|
||||
test-type: "integration"
|
||||
is-external-pr: ${{ github.event_name == 'pull_request_target' && !contains(github.event.pull_request.labels.*.name, 'safe to test') }}
|
||||
install-args: "--extra postgres --extra external-tools --extra dev --extra cloud-tool-sandbox --extra google"
|
||||
install-args: "--extra postgres --extra external-tools --extra dev --extra cloud-tool-sandbox"
|
||||
test-command: "uv run pytest -svv tests/"
|
||||
timeout-minutes: 60
|
||||
runner: '["self-hosted", "gpu", "lmstudio"]'
|
||||
|
||||
2
.github/workflows/test-ollama.yml
vendored
2
.github/workflows/test-ollama.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
with:
|
||||
test-type: "integration"
|
||||
is-external-pr: ${{ github.event_name == 'pull_request_target' && !contains(github.event.pull_request.labels.*.name, 'safe to test') }}
|
||||
install-args: "--extra postgres --extra external-tools --extra dev --extra cloud-tool-sandbox --extra google"
|
||||
install-args: "--extra postgres --extra external-tools --extra dev --extra cloud-tool-sandbox"
|
||||
test-command: "uv run --frozen pytest -svv tests/"
|
||||
timeout-minutes: 60
|
||||
runner: '["self-hosted", "gpu", "ollama"]'
|
||||
|
||||
2
.github/workflows/test-vllm.yml
vendored
2
.github/workflows/test-vllm.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
with:
|
||||
test-type: "integration"
|
||||
is-external-pr: ${{ github.event_name == 'pull_request_target' && !contains(github.event.pull_request.labels.*.name, 'safe to test') }}
|
||||
install-args: "--extra postgres --extra external-tools --extra dev --extra cloud-tool-sandbox --extra google"
|
||||
install-args: "--extra postgres --extra external-tools --extra dev --extra cloud-tool-sandbox"
|
||||
test-command: "uv run --frozen pytest -svv tests/"
|
||||
timeout-minutes: 60
|
||||
runner: '["self-hosted", "gpu", "vllm"]'
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import AsyncGenerator, List
|
||||
from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter
|
||||
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
|
||||
from letta.interfaces.anthropic_streaming_interface import SimpleAnthropicStreamingInterface
|
||||
from letta.interfaces.gemini_streaming_interface import SimpleGeminiStreamingInterface
|
||||
from letta.interfaces.openai_streaming_interface import SimpleOpenAIResponsesStreamingInterface, SimpleOpenAIStreamingInterface
|
||||
from letta.schemas.enums import ProviderType
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
@@ -78,6 +79,12 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
|
||||
run_id=self.run_id,
|
||||
step_id=step_id,
|
||||
)
|
||||
elif self.llm_config.model_endpoint_type in [ProviderType.google_ai, ProviderType.google_vertex]:
|
||||
self.interface = SimpleGeminiStreamingInterface(
|
||||
requires_approval_tools=requires_approval_tools,
|
||||
run_id=self.run_id,
|
||||
step_id=step_id,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Streaming not supported for provider {self.llm_config.model_endpoint_type}")
|
||||
|
||||
|
||||
279
letta/interfaces/gemini_streaming_interface.py
Normal file
279
letta/interfaces/gemini_streaming_interface.py
Normal file
@@ -0,0 +1,279 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncIterator, List, Optional
|
||||
|
||||
from google.genai.types import (
|
||||
GenerateContentResponse,
|
||||
)
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.letta_message import (
|
||||
ApprovalRequestMessage,
|
||||
AssistantMessage,
|
||||
LettaMessage,
|
||||
ReasoningMessage,
|
||||
ToolCallDelta,
|
||||
ToolCallMessage,
|
||||
)
|
||||
from letta.schemas.letta_message_content import (
|
||||
ReasoningContent,
|
||||
TextContent,
|
||||
ToolCallContent,
|
||||
)
|
||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
|
||||
from letta.utils import get_tool_call_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SimpleGeminiStreamingInterface:
|
||||
"""
|
||||
Encapsulates the logic for streaming responses from Gemini API:
|
||||
https://ai.google.dev/gemini-api/docs/text-generation#streaming-responses
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
requires_approval_tools: list = [],
|
||||
run_id: str | None = None,
|
||||
step_id: str | None = None,
|
||||
):
|
||||
self.run_id = run_id
|
||||
self.step_id = step_id
|
||||
|
||||
# self.messages = messages
|
||||
# self.tools = tools
|
||||
self.requires_approval_tools = requires_approval_tools
|
||||
# ID responses used
|
||||
self.message_id = None
|
||||
|
||||
# In Gemini streaming, tool call comes all at once
|
||||
self.tool_call_id: str | None = None
|
||||
self.tool_call_name: str | None = None
|
||||
self.tool_call_args: dict | None = None # NOTE: Not a str!
|
||||
|
||||
# NOTE: signature only is included if tools are present
|
||||
self.thinking_signature: str | None = None
|
||||
|
||||
# Regular text content too
|
||||
self.text_content: str | None = None
|
||||
|
||||
# Premake IDs for database writes
|
||||
self.letta_message_id = Message.generate_id()
|
||||
# self.model = model
|
||||
|
||||
# Sadly, Gemini's encrypted reasoning logic forces us to store stream parts in state
|
||||
self.content_parts: List[ReasoningContent | TextContent | ToolCallContent] = []
|
||||
|
||||
def get_content(self) -> List[ReasoningContent | TextContent | ToolCallContent]:
|
||||
"""This is (unusually) in chunked format, instead of merged"""
|
||||
# for content in self.content_parts:
|
||||
# if isinstance(content, ReasoningContent):
|
||||
# # This assumes there is only one signature per turn
|
||||
# content.signature = self.thinking_signature
|
||||
return self.content_parts
|
||||
|
||||
def get_tool_call_object(self) -> ToolCall:
|
||||
"""Useful for agent loop"""
|
||||
if self.tool_call_id is None:
|
||||
raise ValueError("No tool call ID available")
|
||||
if self.tool_call_name is None:
|
||||
raise ValueError("No tool call name available")
|
||||
if self.tool_call_args is None:
|
||||
raise ValueError("No tool call arguments available")
|
||||
|
||||
# TODO use json_dumps?
|
||||
tool_call_args_str = json.dumps(self.tool_call_args)
|
||||
|
||||
return ToolCall(
|
||||
id=self.tool_call_id,
|
||||
type="function",
|
||||
function=FunctionCall(
|
||||
name=self.tool_call_name,
|
||||
arguments=tool_call_args_str,
|
||||
),
|
||||
)
|
||||
|
||||
async def process(
|
||||
self,
|
||||
stream: AsyncIterator[GenerateContentResponse],
|
||||
ttft_span: Optional["Span"] = None,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
"""
|
||||
Iterates over the Gemini stream, yielding SSE events.
|
||||
It also collects tokens and detects if a tool call is triggered.
|
||||
"""
|
||||
prev_message_type = None
|
||||
message_index = 0
|
||||
try:
|
||||
async for event in stream:
|
||||
try:
|
||||
async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
|
||||
new_message_type = message.message_type
|
||||
if new_message_type != prev_message_type:
|
||||
if prev_message_type != None:
|
||||
message_index += 1
|
||||
prev_message_type = new_message_type
|
||||
yield message
|
||||
except asyncio.CancelledError as e:
|
||||
import traceback
|
||||
|
||||
logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc())
|
||||
async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
|
||||
new_message_type = message.message_type
|
||||
if new_message_type != prev_message_type:
|
||||
if prev_message_type != None:
|
||||
message_index += 1
|
||||
prev_message_type = new_message_type
|
||||
yield message
|
||||
|
||||
# Don't raise the exception here
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
logger.error("Error processing stream: %s\n%s", e, traceback.format_exc())
|
||||
if ttft_span:
|
||||
ttft_span.add_event(
|
||||
name="stop_reason",
|
||||
attributes={"stop_reason": StopReasonType.error.value, "error": str(e), "stacktrace": traceback.format_exc()},
|
||||
)
|
||||
yield LettaStopReason(stop_reason=StopReasonType.error)
|
||||
raise e
|
||||
finally:
|
||||
logger.info("GeminiStreamingInterface: Stream processing complete.")
|
||||
|
||||
async def _process_event(
|
||||
self,
|
||||
event: GenerateContentResponse,
|
||||
ttft_span: Optional["Span"] = None,
|
||||
prev_message_type: Optional[str] = None,
|
||||
message_index: int = 0,
|
||||
) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
|
||||
# Every event has usage data + model info on it,
|
||||
# so we can continually extract
|
||||
self.model = event.model_version
|
||||
self.message_id = event.response_id
|
||||
usage_metadata = event.usage_metadata
|
||||
if usage_metadata:
|
||||
if usage_metadata.prompt_token_count:
|
||||
self.input_tokens = usage_metadata.prompt_token_count
|
||||
if usage_metadata.total_token_count:
|
||||
self.output_tokens = usage_metadata.total_token_count - usage_metadata.prompt_token_count
|
||||
|
||||
if not event.candidates or len(event.candidates) == 0:
|
||||
return
|
||||
else:
|
||||
# NOTE: should always be len 1
|
||||
candidate = event.candidates[0]
|
||||
|
||||
for part in candidate.content.parts:
|
||||
# NOTE: the thought signature often comes after the thought text, eg with the tool call
|
||||
if part.thought_signature:
|
||||
# NOTE: the thought_signature comes on the Part with the function_call
|
||||
thought_signature = part.thought_signature
|
||||
self.thinking_signature = thought_signature
|
||||
# TODO: support reasoning
|
||||
# yield ReasoningMessage(
|
||||
# id=self.letta_message_id,
|
||||
# date=datetime.now(timezone.utc).isoformat(),
|
||||
# otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
# source="reasoner_model",
|
||||
# reasoning="",
|
||||
# signature=base64.b64encode(thought_signature).decode('utf-8'),
|
||||
# )
|
||||
|
||||
# Thinking summary content part (bool means text is thought part)
|
||||
if part.thought:
|
||||
reasoning_summary = part.text
|
||||
self.thinking_summaries.append(reasoning_summary)
|
||||
yield ReasoningMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
source="reasoner_model",
|
||||
reasoning=reasoning_summary,
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
self.content_parts.append(
|
||||
ReasoningContent(
|
||||
is_native=True,
|
||||
reasoning=reasoning_summary,
|
||||
signature=self.thinking_signature,
|
||||
)
|
||||
)
|
||||
|
||||
# Plain text content part
|
||||
elif part.text:
|
||||
content = part.text
|
||||
self.text_content = content if self.text_content is None else self.text_content + content
|
||||
yield AssistantMessage(
|
||||
id=self.letta_message_id,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
date=datetime.now(timezone.utc),
|
||||
content=content,
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
self.content_parts.append(
|
||||
TextContent(
|
||||
text=content,
|
||||
)
|
||||
)
|
||||
|
||||
# Tool call function part
|
||||
# NOTE: in gemini, this comes all at once, and the args are JSON dict, not stringified
|
||||
elif part.function_call:
|
||||
function_call = part.function_call
|
||||
|
||||
# Look for call_id, name, and possibly arguments (though likely always empty string)
|
||||
call_id = get_tool_call_id()
|
||||
name = function_call.name
|
||||
arguments = function_call.args # NOTE: dict, not str
|
||||
arguments_str = json.dumps(arguments) # NOTE: use json_dumps?
|
||||
|
||||
self.tool_call_id = call_id
|
||||
self.tool_call_name = name
|
||||
self.tool_call_args = arguments
|
||||
|
||||
if self.tool_call_name and self.tool_call_name in self.requires_approval_tools:
|
||||
yield ApprovalRequestMessage(
|
||||
id=self.letta_message_id,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=name,
|
||||
arguments=arguments_str,
|
||||
tool_call_id=call_id,
|
||||
),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
else:
|
||||
yield ToolCallMessage(
|
||||
id=self.letta_message_id,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=name,
|
||||
arguments=arguments_str,
|
||||
tool_call_id=call_id,
|
||||
),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
self.content_parts.append(
|
||||
ToolCallContent(
|
||||
id=call_id,
|
||||
name=name,
|
||||
input=arguments,
|
||||
signature=self.thinking_signature,
|
||||
)
|
||||
)
|
||||
@@ -26,6 +26,7 @@ from letta.schemas.letta_message import (
|
||||
AssistantMessage,
|
||||
HiddenReasoningMessage,
|
||||
LettaMessage,
|
||||
MessageType,
|
||||
ReasoningMessage,
|
||||
SystemMessage,
|
||||
ToolCall,
|
||||
@@ -300,6 +301,13 @@ class Message(BaseMessage):
|
||||
if self.role == MessageRole.assistant:
|
||||
if self.content:
|
||||
messages.extend(self._convert_reasoning_messages(text_is_assistant_message=text_is_assistant_message))
|
||||
for i in range(len(messages) - 1, -1, -1):
|
||||
if i > 0 and messages[i].message_type == messages[i - 1].message_type:
|
||||
if messages[i].message_type == MessageType.reasoning_message:
|
||||
messages[i - 1].reasoning = messages[i - 1].reasoning + messages.pop(i).reasoning
|
||||
elif messages[i].message_type == MessageType.assistant_message:
|
||||
messages[i - 1].content = messages[i - 1].content + messages.pop(i).content
|
||||
|
||||
if self.tool_calls is not None:
|
||||
messages.extend(
|
||||
self._convert_tool_call_messages(
|
||||
|
||||
@@ -1338,6 +1338,8 @@ async def send_message_streaming(
|
||||
"deepseek",
|
||||
]
|
||||
model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock", "deepseek"]
|
||||
if agent.agent_type == AgentType.letta_v1_agent and agent.llm_config.model_endpoint_type in ["google_ai", "google_vertex"]:
|
||||
model_compatible_token_streaming = True
|
||||
|
||||
# Create a new run for execution tracking
|
||||
if settings.track_agent_run:
|
||||
|
||||
@@ -70,6 +70,7 @@ dependencies = [
|
||||
"ruff[dev]>=0.12.10",
|
||||
"trafilatura",
|
||||
"readability-lxml",
|
||||
"google-genai>=1.15.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
@@ -101,7 +102,6 @@ bedrock = [
|
||||
"boto3>=1.36.24",
|
||||
"aioboto3>=14.3.0",
|
||||
]
|
||||
google = ["google-genai>=1.15.0"]
|
||||
|
||||
# ====== Development ======
|
||||
dev = [
|
||||
|
||||
8
uv.lock
generated
8
uv.lock
generated
@@ -2428,6 +2428,7 @@ dependencies = [
|
||||
{ name = "docstring-parser" },
|
||||
{ name = "exa-py" },
|
||||
{ name = "faker" },
|
||||
{ name = "google-genai" },
|
||||
{ name = "grpcio" },
|
||||
{ name = "grpcio-tools" },
|
||||
{ name = "html2text" },
|
||||
@@ -2526,9 +2527,6 @@ external-tools = [
|
||||
{ name = "turbopuffer" },
|
||||
{ name = "wikipedia" },
|
||||
]
|
||||
google = [
|
||||
{ name = "google-genai" },
|
||||
]
|
||||
modal = [
|
||||
{ name = "modal" },
|
||||
]
|
||||
@@ -2584,7 +2582,7 @@ requires-dist = [
|
||||
{ name = "fastapi", marker = "extra == 'desktop'", specifier = ">=0.115.6" },
|
||||
{ name = "fastapi", marker = "extra == 'server'", specifier = ">=0.115.6" },
|
||||
{ name = "google-cloud-profiler", marker = "extra == 'experimental'", specifier = ">=4.1.0" },
|
||||
{ name = "google-genai", marker = "extra == 'google'", specifier = ">=1.15.0" },
|
||||
{ name = "google-genai", specifier = ">=1.15.0" },
|
||||
{ name = "granian", extras = ["uvloop", "reload"], marker = "extra == 'experimental'", specifier = ">=2.3.2" },
|
||||
{ name = "grpcio", specifier = ">=1.68.1" },
|
||||
{ name = "grpcio-tools", specifier = ">=1.68.1" },
|
||||
@@ -2669,7 +2667,7 @@ requires-dist = [
|
||||
{ name = "wikipedia", marker = "extra == 'desktop'", specifier = ">=1.4.0" },
|
||||
{ name = "wikipedia", marker = "extra == 'external-tools'", specifier = ">=1.4.0" },
|
||||
]
|
||||
provides-extras = ["postgres", "redis", "pinecone", "sqlite", "experimental", "server", "bedrock", "google", "dev", "cloud-tool-sandbox", "modal", "external-tools", "desktop"]
|
||||
provides-extras = ["postgres", "redis", "pinecone", "sqlite", "experimental", "server", "bedrock", "dev", "cloud-tool-sandbox", "modal", "external-tools", "desktop"]
|
||||
|
||||
[[package]]
|
||||
name = "letta-client"
|
||||
|
||||
Reference in New Issue
Block a user