diff --git a/letta/llm_api/google_vertex_client.py b/letta/llm_api/google_vertex_client.py index 180b38ab..123d46ee 100644 --- a/letta/llm_api/google_vertex_client.py +++ b/letta/llm_api/google_vertex_client.py @@ -1,6 +1,6 @@ import json import uuid -from typing import List, Optional +from typing import AsyncIterator, List, Optional from google import genai from google.genai import errors @@ -138,6 +138,15 @@ class GoogleVertexClient(LLMClientBase): raise RuntimeError("Failed to get response data after all retries") return response_data + @trace_method + async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncIterator[GenerateContentResponse]: + client = self._get_client() + return await client.aio.models.generate_content_stream( + model=llm_config.model, + contents=request_data["contents"], + config=request_data["config"], + ) + @staticmethod def add_dummy_model_messages(messages: List[dict]) -> List[dict]: """Google AI API requires all function call returns are immediately followed by a 'model' role message.