diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py
index 78bc5c62..0bbdc015 100644
--- a/letta/agents/letta_agent.py
+++ b/letta/agents/letta_agent.py
@@ -75,6 +75,68 @@ class LettaAgent(BaseAgent):
             new_in_context_messages=new_in_context_messages, use_assistant_message=use_assistant_message, usage=usage
         )
 
+    @trace_method
+    async def step_stream_no_tokens(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True):
+        agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor)
+        current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
+            input_messages, agent_state, self.message_manager, self.actor
+        )
+        tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
+        llm_client = LLMClient.create(
+            provider_type=agent_state.llm_config.model_endpoint_type,
+            put_inner_thoughts_first=True,
+            actor=self.actor,
+        )
+        usage = LettaUsageStatistics()
+        for _ in range(max_steps):
+            response = await self._get_ai_reply(
+                llm_client=llm_client,
+                in_context_messages=current_in_context_messages + new_in_context_messages,
+                agent_state=agent_state,
+                tool_rules_solver=tool_rules_solver,
+                stream=False,
+                # TODO: also pass in reasoning content
+            )
+
+            if not response.choices[0].message.tool_calls:
+                # TODO: make into a real error
+                raise ValueError("No tool calls found in response, model must make a tool call")
+            tool_call = response.choices[0].message.tool_calls[0]
+            reasoning = [TextContent(text=response.choices[0].message.content)]  # reasoning placed into content for legacy reasons
+
+            persisted_messages, should_continue = await self._handle_ai_response(
+                tool_call, agent_state, tool_rules_solver, reasoning_content=reasoning
+            )
+            self.response_messages.extend(persisted_messages)
+            new_in_context_messages.extend(persisted_messages)
+
+            # stream step
+            # TODO: improve TTFT
+            filter_user_messages = [m for m in persisted_messages if m.role != "user"]
+            letta_messages = Message.to_letta_messages_from_list(
+                filter_user_messages, use_assistant_message=use_assistant_message, reverse=False
+            )
+            for message in letta_messages:
+                yield f"data: {message.model_dump_json()}\n\n"
+
+            # update usage
+            # TODO: add run_id
+            usage.step_count += 1
+            usage.completion_tokens += response.usage.completion_tokens
+            usage.prompt_tokens += response.usage.prompt_tokens
+            usage.total_tokens += response.usage.total_tokens
+
+            if not should_continue:
+                break
+
+        # Extend the in context message ids
+        if not agent_state.message_buffer_autoclear:
+            message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
+            self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
+
+        # Return back usage
+        yield f"data: {usage.model_dump_json()}\n\n"
+
     async def _step(
         self, agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = 10
     ) -> Tuple[List[Message], List[Message], CompletionUsage]:
@@ -98,6 +160,9 @@ class LettaAgent(BaseAgent):
                 # TODO: also pass in reasoning content
             )
 
+            if not response.choices[0].message.tool_calls:
+                # TODO: make into a real error
+                raise ValueError("No tool calls found in response, model must make a tool call")
             tool_call = response.choices[0].message.tool_calls[0]
             reasoning = [TextContent(text=response.choices[0].message.content)]  # reasoning placed into content for legacy reasons
 
@@ -126,7 +191,7 @@ class LettaAgent(BaseAgent):
 
     @trace_method
     async def step_stream(
-        self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True, stream_tokens: bool = False
+        self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True
    ) -> AsyncGenerator[str, None]:
         """
         Main streaming loop that yields partial tokens.
@@ -164,6 +229,9 @@ class LettaAgent(BaseAgent):
                 use_assistant_message=use_assistant_message,
                 put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
             )
+        else:
+            raise ValueError(f"Streaming not supported for {agent_state.llm_config}")
+
         async for chunk in interface.process(stream):
             yield f"data: {chunk.model_dump_json()}\n\n"
diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py
index 87619ca3..fbfc67cd 100644
--- a/letta/server/rest_api/routers/v1/agents.py
+++ b/letta/server/rest_api/routers/v1/agents.py
@@ -635,7 +635,7 @@ async def send_message(
     agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent
     experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false"
     feature_enabled = settings.use_experimental or experimental_header.lower() == "true"
-    model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "google_vertex", "google_ai"]
+    model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai"]
 
     if agent_eligible and feature_enabled and model_compatible:
         experimental_agent = LettaAgent(
@@ -706,13 +706,18 @@ async def send_message_streaming(
             passage_manager=server.passage_manager,
             actor=actor,
         )
-
-        result = StreamingResponse(
-            experimental_agent.step_stream(
-                request.messages, max_steps=10, use_assistant_message=request.use_assistant_message, stream_tokens=request.stream_tokens
-            ),
-            media_type="text/event-stream",
-        )
+        if request.stream_tokens:
+            result = StreamingResponse(
+                experimental_agent.step_stream(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message),
+                media_type="text/event-stream",
+            )
+        else:
+            result = StreamingResponse(
+                experimental_agent.step_stream_no_tokens(
+                    request.messages, max_steps=10, use_assistant_message=request.use_assistant_message
+                ),
+                media_type="text/event-stream",
+            )
     else:
         result = await server.send_message_to_agent(
             agent_id=agent_id,
diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py
index d30a6418..e1784820 100644
--- a/tests/integration_test_send_message.py
+++ b/tests/integration_test_send_message.py
@@ -636,6 +636,33 @@ async def test_streaming_tool_call_async_client(
     assert_tool_call_response(messages, streaming=True)
 
 
+@pytest.mark.parametrize(
+    "llm_config",
+    TESTED_LLM_CONFIGS,
+    ids=[c.model for c in TESTED_LLM_CONFIGS],
+)
+def test_step_streaming_greeting_with_assistant_message(
+    disable_e2b_api_key: Any,
+    client: Letta,
+    agent_state: AgentState,
+    llm_config: LLMConfig,
+) -> None:
+    """
+    Tests sending a streaming message with a synchronous client.
+    Checks that each chunk in the stream has the correct message types.
+    """
+    agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
+    response = client.agents.messages.create_stream(
+        agent_id=agent_state.id,
+        messages=USER_MESSAGE_GREETING,
+        stream_tokens=False,
+    )
+    messages = []
+    for message in response:
+        messages.append(message)
+    assert_greeting_with_assistant_message_response(messages, streaming=True)
+
+
 @pytest.mark.parametrize(
     "llm_config",
     TESTED_LLM_CONFIGS,