From f7755d837a0c18191775f7fda3cca06dd2d06542 Mon Sep 17 00:00:00 2001 From: cthomas Date: Thu, 2 Oct 2025 22:36:04 -0700 Subject: [PATCH] feat: add gemini streaming to new agent loop (#5109) * feat: add gemini streaming to new agent loop * add google as required dependency * support storing all content parts * remove extra google references --- .github/workflows/core-unit-sqlite-test.yaml | 2 +- .github/workflows/core-unit-test.yml | 2 +- .github/workflows/model-sweep.yaml | 2 +- .../send-message-integration-tests.yml | 2 +- .github/workflows/test-lmstudio.yml | 2 +- .github/workflows/test-ollama.yml | 2 +- .github/workflows/test-vllm.yml | 2 +- letta/adapters/simple_llm_stream_adapter.py | 7 + .../interfaces/gemini_streaming_interface.py | 279 ++++++++++++++++++ letta/schemas/message.py | 8 + letta/server/rest_api/routers/v1/agents.py | 2 + pyproject.toml | 2 +- uv.lock | 8 +- 13 files changed, 307 insertions(+), 13 deletions(-) create mode 100644 letta/interfaces/gemini_streaming_interface.py diff --git a/.github/workflows/core-unit-sqlite-test.yaml b/.github/workflows/core-unit-sqlite-test.yaml index 128feb83..fe1a0d2f 100644 --- a/.github/workflows/core-unit-sqlite-test.yaml +++ b/.github/workflows/core-unit-sqlite-test.yaml @@ -25,7 +25,7 @@ jobs: apps/core/** .github/workflows/reusable-test-workflow.yml .github/workflows/core-unit-sqlite-test.yml - install-args: '--extra postgres --extra external-tools --extra dev --extra cloud-tool-sandbox --extra google --extra sqlite' + install-args: '--extra postgres --extra external-tools --extra dev --extra cloud-tool-sandbox --extra sqlite' timeout-minutes: 15 ref: ${{ github.event.pull_request.head.sha || github.sha }} diff --git a/.github/workflows/core-unit-test.yml b/.github/workflows/core-unit-test.yml index 81c75b8e..54a18cf7 100644 --- a/.github/workflows/core-unit-test.yml +++ b/.github/workflows/core-unit-test.yml @@ -26,7 +26,7 @@ jobs: ** .github/workflows/reusable-test-workflow.yml .github/workflows/core-unit-test.yml - install-args: '--extra postgres --extra external-tools --extra dev --extra cloud-tool-sandbox --extra google' + install-args: '--extra postgres --extra external-tools --extra dev --extra cloud-tool-sandbox' timeout-minutes: 15 ref: ${{ github.event.pull_request.head.sha || github.sha }} matrix-strategy: | diff --git a/.github/workflows/model-sweep.yaml b/.github/workflows/model-sweep.yaml index 66f226d0..49626699 100644 --- a/.github/workflows/model-sweep.yaml +++ b/.github/workflows/model-sweep.yaml @@ -61,7 +61,7 @@ jobs: - name: Install dependencies shell: bash - run: uv sync --extra dev --extra postgres --extra external-tools --extra cloud-tool-sandbox --extra google + run: uv sync --extra dev --extra postgres --extra external-tools --extra cloud-tool-sandbox - name: Migrate database env: LETTA_PG_PORT: 5432 diff --git a/.github/workflows/send-message-integration-tests.yml b/.github/workflows/send-message-integration-tests.yml index fbee82e7..f7d9cf75 100644 --- a/.github/workflows/send-message-integration-tests.yml +++ b/.github/workflows/send-message-integration-tests.yml @@ -25,7 +25,7 @@ jobs: ** .github/workflows/reusable-test-workflow.yml .github/workflows/send-message-integration-tests.yml - install-args: '--extra dev --extra postgres --extra external-tools --extra cloud-tool-sandbox --extra google --extra redis' + install-args: '--extra dev --extra postgres --extra external-tools --extra cloud-tool-sandbox --extra redis' timeout-minutes: 15 runner: '["self-hosted", "medium"]' ref: ${{ github.event.pull_request.head.sha || github.sha }} diff --git a/.github/workflows/test-lmstudio.yml b/.github/workflows/test-lmstudio.yml index dc1c07b2..19893c84 100644 --- a/.github/workflows/test-lmstudio.yml +++ b/.github/workflows/test-lmstudio.yml @@ -32,7 +32,7 @@ jobs: with: test-type: "integration" is-external-pr: ${{ github.event_name == 'pull_request_target' && !contains(github.event.pull_request.labels.*.name, 'safe to test') }} - install-args: "--extra postgres --extra external-tools --extra dev --extra cloud-tool-sandbox --extra google" + install-args: "--extra postgres --extra external-tools --extra dev --extra cloud-tool-sandbox" test-command: "uv run pytest -svv tests/" timeout-minutes: 60 runner: '["self-hosted", "gpu", "lmstudio"]' diff --git a/.github/workflows/test-ollama.yml b/.github/workflows/test-ollama.yml index 81f9dad6..e5287d00 100644 --- a/.github/workflows/test-ollama.yml +++ b/.github/workflows/test-ollama.yml @@ -32,7 +32,7 @@ jobs: with: test-type: "integration" is-external-pr: ${{ github.event_name == 'pull_request_target' && !contains(github.event.pull_request.labels.*.name, 'safe to test') }} - install-args: "--extra postgres --extra external-tools --extra dev --extra cloud-tool-sandbox --extra google" + install-args: "--extra postgres --extra external-tools --extra dev --extra cloud-tool-sandbox" test-command: "uv run --frozen pytest -svv tests/" timeout-minutes: 60 runner: '["self-hosted", "gpu", "ollama"]' diff --git a/.github/workflows/test-vllm.yml b/.github/workflows/test-vllm.yml index dfed86e5..65b9ba06 100644 --- a/.github/workflows/test-vllm.yml +++ b/.github/workflows/test-vllm.yml @@ -28,7 +28,7 @@ jobs: with: test-type: "integration" is-external-pr: ${{ github.event_name == 'pull_request_target' && !contains(github.event.pull_request.labels.*.name, 'safe to test') }} - install-args: "--extra postgres --extra external-tools --extra dev --extra cloud-tool-sandbox --extra google" + install-args: "--extra postgres --extra external-tools --extra dev --extra cloud-tool-sandbox" test-command: "uv run --frozen pytest -svv tests/" timeout-minutes: 60 runner: '["self-hosted", "gpu", "vllm"]' diff --git a/letta/adapters/simple_llm_stream_adapter.py b/letta/adapters/simple_llm_stream_adapter.py index 1e55a07e..115f0997 100644 --- a/letta/adapters/simple_llm_stream_adapter.py +++ b/letta/adapters/simple_llm_stream_adapter.py @@ -3,6 +3,7 @@ from typing import AsyncGenerator, List from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter from letta.helpers.datetime_helpers import get_utc_timestamp_ns from letta.interfaces.anthropic_streaming_interface import SimpleAnthropicStreamingInterface +from letta.interfaces.gemini_streaming_interface import SimpleGeminiStreamingInterface from letta.interfaces.openai_streaming_interface import SimpleOpenAIResponsesStreamingInterface, SimpleOpenAIStreamingInterface from letta.schemas.enums import ProviderType from letta.schemas.letta_message import LettaMessage @@ -78,6 +79,12 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter): run_id=self.run_id, step_id=step_id, ) + elif self.llm_config.model_endpoint_type in [ProviderType.google_ai, ProviderType.google_vertex]: + self.interface = SimpleGeminiStreamingInterface( + requires_approval_tools=requires_approval_tools, + run_id=self.run_id, + step_id=step_id, + ) else: raise ValueError(f"Streaming not supported for provider {self.llm_config.model_endpoint_type}") diff --git a/letta/interfaces/gemini_streaming_interface.py b/letta/interfaces/gemini_streaming_interface.py new file mode 100644 index 00000000..b493f475 --- /dev/null +++ b/letta/interfaces/gemini_streaming_interface.py @@ -0,0 +1,279 @@ +import asyncio +import base64 +import json +from collections.abc import AsyncGenerator +from datetime import datetime, timezone +from typing import AsyncIterator, List, Optional + +from google.genai.types import ( + GenerateContentResponse, +) + +from letta.log import get_logger +from letta.schemas.letta_message import ( + ApprovalRequestMessage, + AssistantMessage, + LettaMessage, + ReasoningMessage, + ToolCallDelta, + ToolCallMessage, +) +from letta.schemas.letta_message_content import ( + ReasoningContent, + TextContent, + ToolCallContent, +) +from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType +from letta.schemas.message import Message +from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall +from letta.utils import get_tool_call_id + +logger = get_logger(__name__) + + +class SimpleGeminiStreamingInterface: + """ + Encapsulates the logic for streaming responses from Gemini API: + https://ai.google.dev/gemini-api/docs/text-generation#streaming-responses + """ + + def __init__( + self, + requires_approval_tools: list = [], + run_id: str | None = None, + step_id: str | None = None, + ): + self.run_id = run_id + self.step_id = step_id + + # self.messages = messages + # self.tools = tools + self.requires_approval_tools = requires_approval_tools + # ID responses used + self.message_id = None + + # In Gemini streaming, tool call comes all at once + self.tool_call_id: str | None = None + self.tool_call_name: str | None = None + self.tool_call_args: dict | None = None # NOTE: Not a str! + + # NOTE: signature only is included if tools are present + self.thinking_signature: str | None = None + + # Regular text content too + self.text_content: str | None = None + + # Premake IDs for database writes + self.letta_message_id = Message.generate_id() + # self.model = model + + # Sadly, Gemini's encrypted reasoning logic forces us to store stream parts in state + self.content_parts: List[ReasoningContent | TextContent | ToolCallContent] = [] + + def get_content(self) -> List[ReasoningContent | TextContent | ToolCallContent]: + """This is (unusually) in chunked format, instead of merged""" + # for content in self.content_parts: + # if isinstance(content, ReasoningContent): + # # This assumes there is only one signature per turn + # content.signature = self.thinking_signature + return self.content_parts + + def get_tool_call_object(self) -> ToolCall: + """Useful for agent loop""" + if self.tool_call_id is None: + raise ValueError("No tool call ID available") + if self.tool_call_name is None: + raise ValueError("No tool call name available") + if self.tool_call_args is None: + raise ValueError("No tool call arguments available") + + # TODO use json_dumps? + tool_call_args_str = json.dumps(self.tool_call_args) + + return ToolCall( + id=self.tool_call_id, + type="function", + function=FunctionCall( + name=self.tool_call_name, + arguments=tool_call_args_str, + ), + ) + + async def process( + self, + stream: AsyncIterator[GenerateContentResponse], + ttft_span: Optional["Span"] = None, + ) -> AsyncGenerator[LettaMessage | LettaStopReason, None]: + """ + Iterates over the Gemini stream, yielding SSE events. + It also collects tokens and detects if a tool call is triggered. + """ + prev_message_type = None + message_index = 0 + try: + async for event in stream: + try: + async for message in self._process_event(event, ttft_span, prev_message_type, message_index): + new_message_type = message.message_type + if new_message_type != prev_message_type: + if prev_message_type != None: + message_index += 1 + prev_message_type = new_message_type + yield message + except asyncio.CancelledError as e: + import traceback + + logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc()) + async for message in self._process_event(event, ttft_span, prev_message_type, message_index): + new_message_type = message.message_type + if new_message_type != prev_message_type: + if prev_message_type != None: + message_index += 1 + prev_message_type = new_message_type + yield message + + # Don't raise the exception here + continue + + except Exception as e: + import traceback + + logger.error("Error processing stream: %s\n%s", e, traceback.format_exc()) + if ttft_span: + ttft_span.add_event( + name="stop_reason", + attributes={"stop_reason": StopReasonType.error.value, "error": str(e), "stacktrace": traceback.format_exc()}, + ) + yield LettaStopReason(stop_reason=StopReasonType.error) + raise e + finally: + logger.info("GeminiStreamingInterface: Stream processing complete.") + + async def _process_event( + self, + event: GenerateContentResponse, + ttft_span: Optional["Span"] = None, + prev_message_type: Optional[str] = None, + message_index: int = 0, + ) -> AsyncGenerator[LettaMessage | LettaStopReason, None]: + # Every event has usage data + model info on it, + # so we can continually extract + self.model = event.model_version + self.message_id = event.response_id + usage_metadata = event.usage_metadata + if usage_metadata: + if usage_metadata.prompt_token_count: + self.input_tokens = usage_metadata.prompt_token_count + if usage_metadata.total_token_count: + self.output_tokens = usage_metadata.total_token_count - usage_metadata.prompt_token_count + + if not event.candidates or len(event.candidates) == 0: + return + else: + # NOTE: should always be len 1 + candidate = event.candidates[0] + + for part in candidate.content.parts: + # NOTE: the thought signature often comes after the thought text, eg with the tool call + if part.thought_signature: + # NOTE: the thought_signature comes on the Part with the function_call + thought_signature = part.thought_signature + self.thinking_signature = thought_signature + # TODO: support reasoning + # yield ReasoningMessage( + # id=self.letta_message_id, + # date=datetime.now(timezone.utc).isoformat(), + # otid=Message.generate_otid_from_id(self.letta_message_id, message_index), + # source="reasoner_model", + # reasoning="", + # signature=base64.b64encode(thought_signature).decode('utf-8'), + # ) + + # Thinking summary content part (bool means text is thought part) + if part.thought: + reasoning_summary = part.text + self.thinking_summaries.append(reasoning_summary) + yield ReasoningMessage( + id=self.letta_message_id, + date=datetime.now(timezone.utc).isoformat(), + otid=Message.generate_otid_from_id(self.letta_message_id, message_index), + source="reasoner_model", + reasoning=reasoning_summary, + run_id=self.run_id, + step_id=self.step_id, + ) + self.content_parts.append( + ReasoningContent( + is_native=True, + reasoning=reasoning_summary, + signature=self.thinking_signature, + ) + ) + + # Plain text content part + elif part.text: + content = part.text + self.text_content = content if self.text_content is None else self.text_content + content + yield AssistantMessage( + id=self.letta_message_id, + otid=Message.generate_otid_from_id(self.letta_message_id, message_index), + date=datetime.now(timezone.utc), + content=content, + run_id=self.run_id, + step_id=self.step_id, + ) + self.content_parts.append( + TextContent( + text=content, + ) + ) + + # Tool call function part + # NOTE: in gemini, this comes all at once, and the args are JSON dict, not stringified + elif part.function_call: + function_call = part.function_call + + # Look for call_id, name, and possibly arguments (though likely always empty string) + call_id = get_tool_call_id() + name = function_call.name + arguments = function_call.args # NOTE: dict, not str + arguments_str = json.dumps(arguments) # NOTE: use json_dumps? + + self.tool_call_id = call_id + self.tool_call_name = name + self.tool_call_args = arguments + + if self.tool_call_name and self.tool_call_name in self.requires_approval_tools: + yield ApprovalRequestMessage( + id=self.letta_message_id, + otid=Message.generate_otid_from_id(self.letta_message_id, message_index), + date=datetime.now(timezone.utc), + tool_call=ToolCallDelta( + name=name, + arguments=arguments_str, + tool_call_id=call_id, + ), + run_id=self.run_id, + step_id=self.step_id, + ) + else: + yield ToolCallMessage( + id=self.letta_message_id, + otid=Message.generate_otid_from_id(self.letta_message_id, message_index), + date=datetime.now(timezone.utc), + tool_call=ToolCallDelta( + name=name, + arguments=arguments_str, + tool_call_id=call_id, + ), + run_id=self.run_id, + step_id=self.step_id, + ) + self.content_parts.append( + ToolCallContent( + id=call_id, + name=name, + input=arguments, + signature=self.thinking_signature, + ) + ) diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 85e80806..769ffcc2 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -26,6 +26,7 @@ from letta.schemas.letta_message import ( AssistantMessage, HiddenReasoningMessage, LettaMessage, + MessageType, ReasoningMessage, SystemMessage, ToolCall, @@ -300,6 +301,13 @@ class Message(BaseMessage): if self.role == MessageRole.assistant: if self.content: messages.extend(self._convert_reasoning_messages(text_is_assistant_message=text_is_assistant_message)) + for i in range(len(messages) - 1, -1, -1): + if i > 0 and messages[i].message_type == messages[i - 1].message_type: + if messages[i].message_type == MessageType.reasoning_message: + messages[i - 1].reasoning = messages[i - 1].reasoning + messages.pop(i).reasoning + elif messages[i].message_type == MessageType.assistant_message: + messages[i - 1].content = messages[i - 1].content + messages.pop(i).content + if self.tool_calls is not None: messages.extend( self._convert_tool_call_messages( diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index a5e675c9..f7860ca4 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -1338,6 +1338,8 @@ async def send_message_streaming( "deepseek", ] model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock", "deepseek"] + if agent.agent_type == AgentType.letta_v1_agent and agent.llm_config.model_endpoint_type in ["google_ai", "google_vertex"]: + model_compatible_token_streaming = True # Create a new run for execution tracking if settings.track_agent_run: diff --git a/pyproject.toml b/pyproject.toml index 26d8eb05..4a30ceae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,7 @@ dependencies = [ "ruff[dev]>=0.12.10", "trafilatura", "readability-lxml", + "google-genai>=1.15.0", ] [project.scripts] @@ -101,7 +102,6 @@ bedrock = [ "boto3>=1.36.24", "aioboto3>=14.3.0", ] -google = ["google-genai>=1.15.0"] # ====== Development ====== dev = [ diff --git a/uv.lock b/uv.lock index a93f5b19..ff39f136 100644 --- a/uv.lock +++ b/uv.lock @@ -2428,6 +2428,7 @@ dependencies = [ { name = "docstring-parser" }, { name = "exa-py" }, { name = "faker" }, + { name = "google-genai" }, { name = "grpcio" }, { name = "grpcio-tools" }, { name = "html2text" }, @@ -2526,9 +2527,6 @@ external-tools = [ { name = "turbopuffer" }, { name = "wikipedia" }, ] -google = [ - { name = "google-genai" }, -] modal = [ { name = "modal" }, ] @@ -2584,7 +2582,7 @@ requires-dist = [ { name = "fastapi", marker = "extra == 'desktop'", specifier = ">=0.115.6" }, { name = "fastapi", marker = "extra == 'server'", specifier = ">=0.115.6" }, { name = "google-cloud-profiler", marker = "extra == 'experimental'", specifier = ">=4.1.0" }, - { name = "google-genai", marker = "extra == 'google'", specifier = ">=1.15.0" }, + { name = "google-genai", specifier = ">=1.15.0" }, { name = "granian", extras = ["uvloop", "reload"], marker = "extra == 'experimental'", specifier = ">=2.3.2" }, { name = "grpcio", specifier = ">=1.68.1" }, { name = "grpcio-tools", specifier = ">=1.68.1" }, @@ -2669,7 +2667,7 @@ requires-dist = [ { name = "wikipedia", marker = "extra == 'desktop'", specifier = ">=1.4.0" }, { name = "wikipedia", marker = "extra == 'external-tools'", specifier = ">=1.4.0" }, ] -provides-extras = ["postgres", "redis", "pinecone", "sqlite", "experimental", "server", "bedrock", "google", "dev", "cloud-tool-sandbox", "modal", "external-tools", "desktop"] +provides-extras = ["postgres", "redis", "pinecone", "sqlite", "experimental", "server", "bedrock", "dev", "cloud-tool-sandbox", "modal", "external-tools", "desktop"] [[package]] name = "letta-client"