From 3ba79db8599d6f7ad92813b3358e9ae9be9ca78e Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Thu, 3 Apr 2025 19:40:48 -0700 Subject: [PATCH] feat: Enable Anthropic streaming on new agent loop (#1550) --- letta/agent.py | 1 + letta/agents/base_agent.py | 11 +- letta/agents/helpers.py | 52 +++ letta/agents/letta_agent.py | 147 +++++--- letta/agents/voice_agent.py | 4 +- .../anthropic_streaming_interface.py | 323 ++++++++++++++++++ letta/llm_api/anthropic_client.py | 52 ++- letta/schemas/enums.py | 5 +- letta/schemas/letta_message.py | 6 +- letta/schemas/llm_config.py | 18 +- letta/server/rest_api/routers/v1/agents.py | 88 ++--- letta/server/rest_api/utils.py | 19 +- letta/services/tool_executor/tool_executor.py | 1 + tests/integration_test_experimental.py | 54 ++- 14 files changed, 652 insertions(+), 129 deletions(-) create mode 100644 letta/agents/helpers.py create mode 100644 letta/interfaces/anthropic_streaming_interface.py diff --git a/letta/agent.py b/letta/agent.py index ed76ad30..978197e7 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -295,6 +295,7 @@ class Agent(BaseAgent): and not self.supports_structured_output and len(self.tool_rules_solver.init_tool_rules) > 0 ): + # TODO: This just seems wrong? What if there are more than 1 init tool rules? force_tool_call = self.tool_rules_solver.init_tool_rules[0].tool_name # Force a tool call if exactly one tool is specified elif step_count is not None and step_count > 0 and len(allowed_tool_names) == 1: diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index c9b31ece..8a793088 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -1,9 +1,10 @@ from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, Optional +from typing import Any, AsyncGenerator, Optional, Union import openai -from letta.schemas.letta_message import UserMessage +from letta.schemas.enums import MessageStreamStatus +from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, UserMessage from letta.schemas.letta_response import LettaResponse from letta.schemas.user import User from letta.services.agent_manager import AgentManager @@ -39,9 +40,11 @@ class BaseAgent(ABC): raise NotImplementedError @abstractmethod - async def step_stream(self, input_message: UserMessage, max_steps: int = 10) -> AsyncGenerator[str, None]: + async def step_stream( + self, input_message: UserMessage, max_steps: int = 10 + ) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]: """ - Main async execution loop for the agent. Implementations must yield messages as SSE events. + Main streaming execution loop for the agent. """ raise NotImplementedError diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py new file mode 100644 index 00000000..9f614510 --- /dev/null +++ b/letta/agents/helpers.py @@ -0,0 +1,52 @@ +from typing import Dict, List, Tuple + +from letta.schemas.agent import AgentState +from letta.schemas.letta_response import LettaResponse +from letta.schemas.message import Message +from letta.schemas.usage import LettaUsageStatistics +from letta.schemas.user import User +from letta.server.rest_api.utils import create_user_message +from letta.services.message_manager import MessageManager + + +def _create_letta_response(new_in_context_messages: list[Message], use_assistant_message: bool) -> LettaResponse: + """ + Converts the newly created/persisted messages into a LettaResponse. 
+ """ + response_messages = [] + for msg in new_in_context_messages: + response_messages.extend(msg.to_letta_message(use_assistant_message=use_assistant_message)) + return LettaResponse(messages=response_messages, usage=LettaUsageStatistics()) + + +def _prepare_in_context_messages( + input_message: Dict, agent_state: AgentState, message_manager: MessageManager, actor: User +) -> Tuple[List[Message], List[Message]]: + """ + Prepares in-context messages for an agent, based on the current state and a new user input. + + Args: + input_message (Dict): The new user input message to process. + agent_state (AgentState): The current state of the agent, including message buffer config. + message_manager (MessageManager): The manager used to retrieve and create messages. + actor (User): The user performing the action, used for access control and attribution. + + Returns: + Tuple[List[Message], List[Message]]: A tuple containing: + - The current in-context messages (existing context for the agent). + - The new in-context messages (messages created from the new input). + """ + + if agent_state.message_buffer_autoclear: + # If autoclear is enabled, only include the most recent system message (usually at index 0) + current_in_context_messages = [message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=actor)[0]] + else: + # Otherwise, include the full list of messages by ID for context + current_in_context_messages = message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=actor) + + # Create a new user message from the input and store it + new_in_context_messages = message_manager.create_many_messages( + [create_user_message(input_message=input_message, agent_id=agent_state.id, actor=actor)], actor=actor + ) + + return current_in_context_messages, new_in_context_messages diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index c0946788..128a1073 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -1,27 +1,32 @@ import asyncio import json import uuid -from typing import Any, AsyncGenerator, Dict, List, Tuple +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union from openai import AsyncStream from openai.types.chat import ChatCompletion, ChatCompletionChunk from letta.agents.base_agent import BaseAgent -from letta.constants import DEFAULT_MESSAGE_TOOL +from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages from letta.helpers import ToolRulesSolver from letta.helpers.datetime_helpers import get_utc_time from letta.helpers.tool_execution_helper import enable_strict_mode +from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface from letta.llm_api.llm_client import LLMClient +from letta.llm_api.llm_client_base import LLMClientBase +from letta.local_llm.constants import INNER_THOUGHTS_KWARG from letta.log import get_logger from letta.orm.enums import ToolType from letta.schemas.agent import AgentState +from letta.schemas.enums import MessageStreamStatus from letta.schemas.letta_message import AssistantMessage +from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.letta_response import LettaResponse from letta.schemas.message import Message, MessageUpdate from letta.schemas.openai.chat_completion_request import UserMessage -from letta.schemas.usage import LettaUsageStatistics +from letta.schemas.openai.chat_completion_response import ToolCall from 
letta.schemas.user import User -from letta.server.rest_api.utils import create_tool_call_messages_from_openai_response, create_user_message +from letta.server.rest_api.utils import create_letta_messages_from_llm_response from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager from letta.services.helpers.agent_manager_helper import compile_system_message @@ -58,76 +63,130 @@ class LettaAgent(BaseAgent): async def step(self, input_message: UserMessage, max_steps: int = 10) -> LettaResponse: input_message = self.pre_process_input_message(input_message) agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor) - # TODO: Extend to beyond just system message - system_message = [self.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=self.actor)[0]] - persisted_letta_messages = self.message_manager.create_many_messages( - [create_user_message(input_message=input_message, agent_id=agent_state.id, actor=self.actor)], actor=self.actor + current_in_context_messages, new_in_context_messages = _prepare_in_context_messages( + input_message, agent_state, self.message_manager, self.actor ) tool_rules_solver = ToolRulesSolver(agent_state.tool_rules) - - # TODO: Note that we do absolutely 0 pulling in of in-context messages here - # TODO: This is specific to B, and needs to be changed + llm_client = LLMClient.create( + llm_config=agent_state.llm_config, + put_inner_thoughts_first=True, + ) for step in range(max_steps): response = await self._get_ai_reply( - in_context_messages=system_message + persisted_letta_messages, + llm_client=llm_client, + in_context_messages=current_in_context_messages + new_in_context_messages, agent_state=agent_state, tool_rules_solver=tool_rules_solver, + stream=False, ) - persisted_messages, should_continue = await self._handle_ai_response(response, agent_state, tool_rules_solver) - persisted_letta_messages.extend(persisted_messages) + + tool_call = response.choices[0].message.tool_calls[0] + persisted_messages, should_continue = await self._handle_ai_response(tool_call, agent_state, tool_rules_solver) + new_in_context_messages.extend(persisted_messages) if not should_continue: break - # Persist messages - # Translate to letta response messages - response_messages = [] - for message in persisted_letta_messages: - response_messages += message.to_letta_message(use_assistant_message=self.use_assistant_message) + # Extend the in context message ids + if not agent_state.message_buffer_autoclear: + message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)] + self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor) - return LettaResponse( - messages=response_messages, - # TODO: Actually populate this - usage=LettaUsageStatistics(), - ) + return _create_letta_response(new_in_context_messages=new_in_context_messages, use_assistant_message=self.use_assistant_message) - async def step_stream(self, input_message: UserMessage, max_steps: int = 10) -> AsyncGenerator[str, None]: + @trace_method + async def step_stream( + self, input_message: UserMessage, max_steps: int = 10, use_assistant_message: bool = False + ) -> AsyncGenerator[str, None]: """ Main streaming loop that yields partial tokens. Whenever we detect a tool call, we yield from _handle_ai_response as well. 
""" - raise NotImplementedError("Not implemented for letta agent") + input_message = self.pre_process_input_message(input_message) + agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor) + current_in_context_messages, new_in_context_messages = _prepare_in_context_messages( + input_message, agent_state, self.message_manager, self.actor + ) + tool_rules_solver = ToolRulesSolver(agent_state.tool_rules) + llm_client = LLMClient.create( + llm_config=agent_state.llm_config, + put_inner_thoughts_first=True, + ) + + for step in range(max_steps): + stream = await self._get_ai_reply( + llm_client=llm_client, + in_context_messages=current_in_context_messages + new_in_context_messages, + agent_state=agent_state, + tool_rules_solver=tool_rules_solver, + stream=True, + ) + + # TODO: THIS IS INCREDIBLY UGLY + # TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED + interface = AnthropicStreamingInterface( + use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=llm_client.llm_config.put_inner_thoughts_in_kwargs + ) + async for chunk in interface.process(stream): + yield f"data: {chunk.model_dump_json()}\n\n" + + # Process resulting stream content + tool_call = interface.get_tool_call_object() + reasoning_content = interface.get_reasoning_content() + persisted_messages, should_continue = await self._handle_ai_response( + tool_call, + agent_state, + tool_rules_solver, + reasoning_content=reasoning_content, + pre_computed_assistant_message_id=interface.letta_assistant_message_id, + pre_computed_tool_message_id=interface.letta_tool_message_id, + ) + new_in_context_messages.extend(persisted_messages) + + if not should_continue: + break + + # Extend the in context message ids + if not agent_state.message_buffer_autoclear: + message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)] + self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor) + + # TODO: Also yield out a letta usage stats SSE + + yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n" @trace_method async def _get_ai_reply( self, + llm_client: LLMClientBase, in_context_messages: List[Message], agent_state: AgentState, tool_rules_solver: ToolRulesSolver, + stream: bool, ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]: in_context_messages = self._rebuild_memory(in_context_messages, agent_state) tools = [ t for t in agent_state.tools - if t.tool_type in {ToolType.CUSTOM} - or (t.tool_type == ToolType.LETTA_CORE and t.name == DEFAULT_MESSAGE_TOOL) + if t.tool_type in {ToolType.CUSTOM, ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE} or (t.tool_type == ToolType.LETTA_MULTI_AGENT_CORE and t.name == "send_message_to_agents_matching_tags") ] - valid_tool_names = set(tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools]))) - allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in valid_tool_names] + valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools])) + # TODO: Copied from legacy agent loop, so please be cautious + # Set force tool + force_tool_call = None + if len(valid_tool_names) == 1: + force_tool_call = valid_tool_names[0] - llm_client = LLMClient.create( - llm_config=agent_state.llm_config, - put_inner_thoughts_first=True, - ) + allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)] response = await 
llm_client.send_llm_request_async(
            messages=in_context_messages,
            tools=allowed_tools,
-            tool_call=None,
-            stream=False,
+            force_tool_call=force_tool_call,
+            stream=stream,
        )

        return response

@@ -135,19 +194,18 @@ class LettaAgent(BaseAgent):
    @trace_method
    async def _handle_ai_response(
        self,
-        chat_completion_response: ChatCompletion,
+        tool_call: ToolCall,
        agent_state: AgentState,
        tool_rules_solver: ToolRulesSolver,
+        reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
+        pre_computed_assistant_message_id: Optional[str] = None,
+        pre_computed_tool_message_id: Optional[str] = None,
    ) -> Tuple[List[Message], bool]:
        """
        Handle the final AI response once streaming is complete: execute the
        chosen tool call, persist the resulting messages, and return them
        along with a flag indicating whether the agent should keep stepping.
        """
-        # TODO: Some key assumptions here.
-        # TODO: Assume every call has a tool call, i.e. tool_choice is REQUIRED
-        tool_call = chat_completion_response.choices[0].message.tool_calls[0]
-
        tool_call_name = tool_call.function.name
        tool_call_args_str = tool_call.function.arguments

@@ -158,6 +216,8 @@ class LettaAgent(BaseAgent):

        # Get request heartbeats and coerce to bool
        request_heartbeat = tool_args.pop("request_heartbeat", False)
+        # Pre-emptively pop out inner_thoughts
+        tool_args.pop(INNER_THOUGHTS_KWARG, "")

        # This is necessary because non-structured outputs sometimes make mistakes
        if not isinstance(request_heartbeat, bool):

@@ -186,7 +246,7 @@ class LettaAgent(BaseAgent):
            continue_stepping = True

        # 5. Persist to DB
-        tool_call_messages = create_tool_call_messages_from_openai_response(
+        tool_call_messages = create_letta_messages_from_llm_response(
            agent_id=agent_state.id,
            model=agent_state.llm_config.model,
            function_name=tool_call_name,
@@ -196,6 +256,9 @@ class LettaAgent(BaseAgent):
            function_response=tool_result,
            actor=self.actor,
            add_heartbeat_request_system_message=continue_stepping,
+            reasoning_content=reasoning_content,
+            pre_computed_assistant_message_id=pre_computed_assistant_message_id,
+            pre_computed_tool_message_id=pre_computed_tool_message_id,
        )
        persisted_messages = self.message_manager.create_many_messages(tool_call_messages, actor=self.actor)
diff --git a/letta/agents/voice_agent.py b/letta/agents/voice_agent.py
index 48425589..16f8ff97 100644
--- a/letta/agents/voice_agent.py
+++ b/letta/agents/voice_agent.py
@@ -34,7 +34,7 @@ from letta.schemas.user import User
 from letta.server.rest_api.utils import (
     convert_letta_messages_to_openai,
     create_assistant_messages_from_openai_response,
-    create_tool_call_messages_from_openai_response,
+    create_letta_messages_from_llm_response,
     create_user_message,
 )
 from letta.services.agent_manager import AgentManager
@@ -207,7 +207,7 @@ class VoiceAgent(BaseAgent):
                 in_memory_message_history.append(heartbeat_user_message.model_dump())

                 # 5. Also store in DB
-                tool_call_messages = create_tool_call_messages_from_openai_response(
+                tool_call_messages = create_letta_messages_from_llm_response(
                     agent_id=agent_state.id,
                     model=agent_state.llm_config.model,
                     function_name=tool_call_name,
diff --git a/letta/interfaces/anthropic_streaming_interface.py b/letta/interfaces/anthropic_streaming_interface.py
new file mode 100644
index 00000000..fc0741b9
--- /dev/null
+++ b/letta/interfaces/anthropic_streaming_interface.py
@@ -0,0 +1,323 @@
+from datetime import datetime, timezone
+from enum import Enum
+from typing import AsyncGenerator, List, Union
+
+from anthropic import AsyncStream
+from anthropic.types.beta import (
+    BetaInputJSONDelta,
+    BetaRawContentBlockDeltaEvent,
+    BetaRawContentBlockStartEvent,
+    BetaRawContentBlockStopEvent,
+    BetaRawMessageDeltaEvent,
+    BetaRawMessageStartEvent,
+    BetaRawMessageStopEvent,
+    BetaRawMessageStreamEvent,
+    BetaRedactedThinkingBlock,
+    BetaSignatureDelta,
+    BetaTextBlock,
+    BetaTextDelta,
+    BetaThinkingBlock,
+    BetaThinkingDelta,
+    BetaToolUseBlock,
+)
+
+from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
+from letta.local_llm.constants import INNER_THOUGHTS_KWARG
+from letta.log import get_logger
+from letta.schemas.letta_message import (
+    AssistantMessage,
+    HiddenReasoningMessage,
+    LettaMessage,
+    ReasoningMessage,
+    ToolCallDelta,
+    ToolCallMessage,
+)
+from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent
+from letta.schemas.message import Message
+from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
+from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
+
+logger = get_logger(__name__)
+
+
+# TODO: These modes aren't used right now - but can be useful if we do multiple sequential tool calls within one Claude message
+class EventMode(Enum):
+    TEXT = "TEXT"
+    TOOL_USE = "TOOL_USE"
+    THINKING = "THINKING"
+    REDACTED_THINKING = "REDACTED_THINKING"
+
+
+class AnthropicStreamingInterface:
+    """
+    Encapsulates the logic for streaming responses from Anthropic.
+    This class handles parsing of partial tokens, pre-execution messages,
+    and detection of tool call events.
+ """ + + def __init__(self, use_assistant_message: bool = False, put_inner_thoughts_in_kwarg: bool = False): + self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser() + self.use_assistant_message = use_assistant_message + + # Premake IDs for database writes + self.letta_assistant_message_id = Message.generate_id() + self.letta_tool_message_id = Message.generate_id() + + self.anthropic_mode = None + self.message_id = None + self.accumulated_inner_thoughts = [] + self.tool_call_id = None + self.tool_call_name = None + self.accumulated_tool_call_args = [] + self.previous_parse = {} + + # usage trackers + self.input_tokens = 0 + self.output_tokens = 0 + + # reasoning object trackers + self.reasoning_messages = [] + + # Buffer to hold tool call messages until inner thoughts are complete + self.tool_call_buffer = [] + self.inner_thoughts_complete = False + self.put_inner_thoughts_in_kwarg = put_inner_thoughts_in_kwarg + + def get_tool_call_object(self) -> ToolCall: + """Useful for agent loop""" + return ToolCall( + id=self.tool_call_id, function=FunctionCall(arguments="".join(self.accumulated_tool_call_args), name=self.tool_call_name) + ) + + def _check_inner_thoughts_complete(self, combined_args: str) -> bool: + """ + Check if inner thoughts are complete in the current tool call arguments + by looking for a closing quote after the inner_thoughts field + """ + if not self.put_inner_thoughts_in_kwarg: + # None of the things should have inner thoughts in kwargs + return True + else: + parsed = self.optimistic_json_parser.parse(combined_args) + # TODO: This will break on tools with 0 input + return len(parsed.keys()) > 1 and INNER_THOUGHTS_KWARG in parsed.keys() + + async def process(self, stream: AsyncStream[BetaRawMessageStreamEvent]) -> AsyncGenerator[LettaMessage, None]: + async with stream: + async for event in stream: + # TODO: Support BetaThinkingBlock, BetaRedactedThinkingBlock + if isinstance(event, BetaRawContentBlockStartEvent): + content = event.content_block + + if isinstance(content, BetaTextBlock): + self.anthropic_mode = EventMode.TEXT + # TODO: Can capture citations, etc. + elif isinstance(content, BetaToolUseBlock): + self.anthropic_mode = EventMode.TOOL_USE + self.tool_call_id = content.id + self.tool_call_name = content.name + self.inner_thoughts_complete = False + + if not self.use_assistant_message: + # Buffer the initial tool call message instead of yielding immediately + tool_call_msg = ToolCallMessage( + id=self.letta_tool_message_id, + tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id), + date=datetime.now(timezone.utc).isoformat(), + ) + self.tool_call_buffer.append(tool_call_msg) + elif isinstance(content, BetaThinkingBlock): + self.anthropic_mode = EventMode.THINKING + # TODO: Can capture signature, etc. 
+ elif isinstance(content, BetaRedactedThinkingBlock): + self.anthropic_mode = EventMode.REDACTED_THINKING + + hidden_reasoning_message = HiddenReasoningMessage( + id=self.letta_assistant_message_id, + state="redacted", + hidden_reasoning=content.data, + date=datetime.now(timezone.utc).isoformat(), + ) + self.reasoning_messages.append(hidden_reasoning_message) + yield hidden_reasoning_message + + elif isinstance(event, BetaRawContentBlockDeltaEvent): + delta = event.delta + + if isinstance(delta, BetaTextDelta): + # Safety check + if not self.anthropic_mode == EventMode.TEXT: + raise RuntimeError( + f"Streaming integrity failed - received BetaTextDelta object while not in TEXT EventMode: {delta}" + ) + + # TODO: Strip out more robustly, this is pretty hacky lol + delta.text = delta.text.replace("", "") + self.accumulated_inner_thoughts.append(delta.text) + + reasoning_message = ReasoningMessage( + id=self.letta_assistant_message_id, + reasoning=self.accumulated_inner_thoughts[-1], + date=datetime.now(timezone.utc).isoformat(), + ) + self.reasoning_messages.append(reasoning_message) + yield reasoning_message + + elif isinstance(delta, BetaInputJSONDelta): + if not self.anthropic_mode == EventMode.TOOL_USE: + raise RuntimeError( + f"Streaming integrity failed - received BetaInputJSONDelta object while not in TOOL_USE EventMode: {delta}" + ) + + self.accumulated_tool_call_args.append(delta.partial_json) + combined_args = "".join(self.accumulated_tool_call_args) + current_parsed = self.optimistic_json_parser.parse(combined_args) + + # Start detecting a difference in inner thoughts + previous_inner_thoughts = self.previous_parse.get(INNER_THOUGHTS_KWARG, "") + current_inner_thoughts = current_parsed.get(INNER_THOUGHTS_KWARG, "") + inner_thoughts_diff = current_inner_thoughts[len(previous_inner_thoughts) :] + + if inner_thoughts_diff: + reasoning_message = ReasoningMessage( + id=self.letta_assistant_message_id, + reasoning=inner_thoughts_diff, + date=datetime.now(timezone.utc).isoformat(), + ) + self.reasoning_messages.append(reasoning_message) + yield reasoning_message + + # Check if inner thoughts are complete - if so, flush the buffer + if not self.inner_thoughts_complete and self._check_inner_thoughts_complete(combined_args): + self.inner_thoughts_complete = True + # Flush all buffered tool call messages + for buffered_msg in self.tool_call_buffer: + yield buffered_msg + self.tool_call_buffer = [] + + # Start detecting special case of "send_message" + if self.tool_call_name == DEFAULT_MESSAGE_TOOL and self.use_assistant_message: + previous_send_message = self.previous_parse.get(DEFAULT_MESSAGE_TOOL_KWARG, "") + current_send_message = current_parsed.get(DEFAULT_MESSAGE_TOOL_KWARG, "") + send_message_diff = current_send_message[len(previous_send_message) :] + + # Only stream out if it's not an empty string + if send_message_diff: + yield AssistantMessage( + id=self.letta_assistant_message_id, + content=[TextContent(text=send_message_diff)], + date=datetime.now(timezone.utc).isoformat(), + ) + else: + # Otherwise, it is a normal tool call - buffer or yield based on inner thoughts status + tool_call_msg = ToolCallMessage( + id=self.letta_tool_message_id, + tool_call=ToolCallDelta(arguments=delta.partial_json), + date=datetime.now(timezone.utc).isoformat(), + ) + + if self.inner_thoughts_complete: + yield tool_call_msg + else: + self.tool_call_buffer.append(tool_call_msg) + + # Set previous parse + self.previous_parse = current_parsed + elif isinstance(delta, BetaThinkingDelta): + # Safety check 
+                        if not self.anthropic_mode == EventMode.THINKING:
+                            raise RuntimeError(
+                                f"Streaming integrity failed - received BetaThinkingDelta object while not in THINKING EventMode: {delta}"
+                            )
+
+                        reasoning_message = ReasoningMessage(
+                            id=self.letta_assistant_message_id,
+                            source="reasoner_model",
+                            reasoning=delta.thinking,
+                            date=datetime.now(timezone.utc).isoformat(),
+                        )
+                        self.reasoning_messages.append(reasoning_message)
+                        yield reasoning_message
+                    elif isinstance(delta, BetaSignatureDelta):
+                        # Safety check
+                        if not self.anthropic_mode == EventMode.THINKING:
+                            raise RuntimeError(
+                                f"Streaming integrity failed - received BetaSignatureDelta object while not in THINKING EventMode: {delta}"
+                            )
+
+                        reasoning_message = ReasoningMessage(
+                            id=self.letta_assistant_message_id,
+                            source="reasoner_model",
+                            reasoning="",
+                            date=datetime.now(timezone.utc).isoformat(),
+                            signature=delta.signature,
+                        )
+                        self.reasoning_messages.append(reasoning_message)
+                        yield reasoning_message
+                elif isinstance(event, BetaRawMessageStartEvent):
+                    self.message_id = event.message.id
+                    self.input_tokens += event.message.usage.input_tokens
+                    self.output_tokens += event.message.usage.output_tokens
+                elif isinstance(event, BetaRawMessageDeltaEvent):
+                    self.output_tokens += event.usage.output_tokens
+                elif isinstance(event, BetaRawMessageStopEvent):
+                    # Don't do anything here! We don't want to stop the stream.
+                    pass
+                elif isinstance(event, BetaRawContentBlockStopEvent):
+                    # If we're exiting a tool use block and there are still buffered messages,
+                    # we should flush them now
+                    if self.anthropic_mode == EventMode.TOOL_USE and self.tool_call_buffer:
+                        for buffered_msg in self.tool_call_buffer:
+                            yield buffered_msg
+                        self.tool_call_buffer = []
+
+                    self.anthropic_mode = None
+
+    def get_reasoning_content(self) -> List[Union[TextContent, ReasoningContent, RedactedReasoningContent]]:
+        def _process_group(
+            group: List[Union[ReasoningMessage, HiddenReasoningMessage]], group_type: str
+        ) -> Union[TextContent, ReasoningContent, RedactedReasoningContent]:
+            if group_type == "reasoning":
+                reasoning_text = "".join(chunk.reasoning for chunk in group)
+                is_native = any(chunk.source == "reasoner_model" for chunk in group)
+                signature = next((chunk.signature for chunk in group if chunk.signature is not None), None)
+                if is_native:
+                    return ReasoningContent(is_native=is_native, reasoning=reasoning_text, signature=signature)
+                else:
+                    return TextContent(text=reasoning_text)
+            elif group_type == "redacted":
+                redacted_text = "".join(chunk.hidden_reasoning for chunk in group if chunk.hidden_reasoning is not None)
+                return RedactedReasoningContent(data=redacted_text)
+            else:
+                raise ValueError("Unexpected group type")
+
+        merged = []
+        current_group = []
+        current_group_type = None  # "reasoning" or "redacted"
+
+        for msg in self.reasoning_messages:
+            # Determine the type of the current message
+            if isinstance(msg, HiddenReasoningMessage):
+                msg_type = "redacted"
+            elif isinstance(msg, ReasoningMessage):
+                msg_type = "reasoning"
+            else:
+                raise ValueError("Unexpected message type")
+
+            # Initialize group type if not set
+            if current_group_type is None:
+                current_group_type = msg_type
+
+            # If the type changes, process the current group
+            if msg_type != current_group_type:
+                merged.append(_process_group(current_group, current_group_type))
+                current_group = []
+                current_group_type = msg_type
+
+            current_group.append(msg)
+
+        # Process the final group, if any.
+ if current_group: + merged.append(_process_group(current_group, current_group_type)) + + return merged diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py index 3e4867b0..99e315b8 100644 --- a/letta/llm_api/anthropic_client.py +++ b/letta/llm_api/anthropic_client.py @@ -3,7 +3,9 @@ import re from typing import List, Optional, Union import anthropic +from anthropic import AsyncStream from anthropic.types import Message as AnthropicMessage +from anthropic.types.beta import BetaRawMessageStreamEvent from letta.errors import ( ContextWindowExceededError, @@ -28,6 +30,7 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionResponse from letta.schemas.openai.chat_completion_response import Message as ChoiceMessage from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics from letta.services.provider_manager import ProviderManager +from letta.tracing import trace_method DUMMY_FIRST_USER_MESSAGE = "User initializing bootup sequence." @@ -46,18 +49,28 @@ class AnthropicClient(LLMClientBase): response = await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"]) return response.model_dump() + @trace_method + async def stream_async(self, request_data: dict) -> AsyncStream[BetaRawMessageStreamEvent]: + client = self._get_anthropic_client(async_client=True) + request_data["stream"] = True + return await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"]) + + @trace_method def _get_anthropic_client(self, async_client: bool = False) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]: override_key = ProviderManager().get_anthropic_override_key() if async_client: return anthropic.AsyncAnthropic(api_key=override_key) if override_key else anthropic.AsyncAnthropic() return anthropic.Anthropic(api_key=override_key) if override_key else anthropic.Anthropic() + @trace_method def build_request_data( self, messages: List[PydanticMessage], tools: List[dict], force_tool_call: Optional[str] = None, ) -> dict: + # TODO: This needs to get cleaned up. The logic here is pretty confusing. + # TODO: I really want to get rid of prefixing, it's a recipe for disaster code maintenance wise prefix_fill = True if not self.use_tool_naming: raise NotImplementedError("Only tool calling supported on Anthropic API requests") @@ -73,11 +86,6 @@ class AnthropicClient(LLMClientBase): # Extended Thinking if self.llm_config.enable_reasoner: - assert ( - self.llm_config.max_reasoning_tokens is not None and self.llm_config.max_reasoning_tokens < self.llm_config.max_tokens - ), "max tokens must be greater than thinking budget" - assert not self.llm_config.put_inner_thoughts_in_kwargs, "extended thinking not compatible with put_inner_thoughts_in_kwargs" - data["thinking"] = { "type": "enabled", "budget_tokens": self.llm_config.max_reasoning_tokens, @@ -89,15 +97,35 @@ class AnthropicClient(LLMClientBase): prefix_fill = False # Tools - tools_for_request = ( - [Tool(function=f) for f in tools if f["name"] == force_tool_call] - if force_tool_call is not None - else [Tool(function=f) for f in tools] - ) - if force_tool_call is not None: - self.llm_config.put_inner_thoughts_in_kwargs = True # why do we do this ? 
+ # For an overview on tool choice: + # https://docs.anthropic.com/en/docs/build-with-claude/tool-use/overview + if not tools: + # Special case for summarization path + tools_for_request = None + tool_choice = None + elif force_tool_call is not None: + tool_choice = {"type": "tool", "name": force_tool_call} + tools_for_request = [Tool(function=f) for f in tools if f["name"] == force_tool_call] + + # need to have this setting to be able to put inner thoughts in kwargs + if not self.llm_config.put_inner_thoughts_in_kwargs: + logger.warning( + f"Force setting put_inner_thoughts_in_kwargs to True for Claude because there is a forced tool call: {force_tool_call}" + ) + self.llm_config.put_inner_thoughts_in_kwargs = True + else: + if self.llm_config.put_inner_thoughts_in_kwargs: + # tool_choice_type other than "auto" only plays nice if thinking goes inside the tool calls + tool_choice = {"type": "any", "disable_parallel_tool_use": True} + else: + tool_choice = {"type": "auto", "disable_parallel_tool_use": True} + tools_for_request = [Tool(function=f) for f in tools] if tools is not None else None + + # Add tool choice + data["tool_choice"] = tool_choice # Add inner thoughts kwarg + # TODO: Can probably make this more efficient if len(tools_for_request) > 0 and self.llm_config.put_inner_thoughts_in_kwargs: tools_with_inner_thoughts = add_inner_thoughts_to_functions( functions=[t.function.model_dump() for t in tools_for_request], diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index 9fde25cd..022846d5 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -30,10 +30,11 @@ class JobStatus(str, Enum): class MessageStreamStatus(str, Enum): - # done_generation = "[DONE_GEN]" - # done_step = "[DONE_STEP]" done = "[DONE]" + def model_dump_json(self): + return "[DONE]" + class ToolRuleType(str, Enum): """ diff --git a/letta/schemas/letta_message.py b/letta/schemas/letta_message.py index d24e8783..ffaaa8ea 100644 --- a/letta/schemas/letta_message.py +++ b/letta/schemas/letta_message.py @@ -123,9 +123,9 @@ class ToolCall(BaseModel): class ToolCallDelta(BaseModel): - name: Optional[str] - arguments: Optional[str] - tool_call_id: Optional[str] + name: Optional[str] = None + arguments: Optional[str] = None + tool_call_id: Optional[str] = None def model_dump(self, *args, **kwargs): """ diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py index 194abb7c..255dcc66 100644 --- a/letta/schemas/llm_config.py +++ b/letta/schemas/llm_config.py @@ -1,6 +1,6 @@ from typing import Literal, Optional -from pydantic import BaseModel, ConfigDict, Field, root_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator class LLMConfig(BaseModel): @@ -70,7 +70,8 @@ class LLMConfig(BaseModel): # FIXME hack to silence pydantic protected namespace warning model_config = ConfigDict(protected_namespaces=()) - @root_validator(pre=True) + @model_validator(mode="before") + @classmethod def set_default_put_inner_thoughts(cls, values): """ Dynamically set the default for put_inner_thoughts_in_kwargs based on the model field, @@ -79,15 +80,24 @@ class LLMConfig(BaseModel): model = values.get("model") # Define models where we want put_inner_thoughts_in_kwargs to be False - # For now it is gpt-4 avoid_put_inner_thoughts_in_kwargs = ["gpt-4"] - # Only modify the value if it's None or not provided if values.get("put_inner_thoughts_in_kwargs") is None: values["put_inner_thoughts_in_kwargs"] = False if model in avoid_put_inner_thoughts_in_kwargs else True return values + 
@model_validator(mode="after") + def validate_reasoning_constraints(self) -> "LLMConfig": + if self.enable_reasoner: + if self.max_reasoning_tokens is None: + raise ValueError("max_reasoning_tokens must be set when enable_reasoner is True") + if self.max_tokens is not None and self.max_reasoning_tokens >= self.max_tokens: + raise ValueError("max_tokens must be greater than max_reasoning_tokens (thinking budget)") + if self.put_inner_thoughts_in_kwargs: + raise ValueError("Extended thinking is not compatible with put_inner_thoughts_in_kwargs") + return self + @classmethod def default_config(cls, model_name: str): """ diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 9914faff..82b596fe 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -8,6 +8,7 @@ from fastapi.responses import JSONResponse from marshmallow import ValidationError from pydantic import Field from sqlalchemy.exc import IntegrityError, OperationalError +from starlette.responses import StreamingResponse from letta.agents.letta_agent import LettaAgent from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG @@ -30,7 +31,6 @@ from letta.schemas.user import User from letta.serialize_schemas.pydantic_agent_schema import AgentSchema from letta.server.rest_api.utils import get_letta_server from letta.server.server import SyncServer -from letta.settings import settings # These can be forward refs, but because Fastapi needs them at runtime the must be imported normally @@ -590,8 +590,10 @@ async def send_message( This endpoint accepts a message from a user and processes it through the agent. """ actor = server.user_manager.get_user_or_default(user_id=actor_id) - if settings.use_experimental: - logger.warning("USING EXPERIMENTAL!") + # TODO: This is redundant, remove soon + agent = server.agent_manager.get_agent_by_id(agent_id, actor) + + if agent.llm_config.model_endpoint_type == "anthropic": experimental_agent = LettaAgent( agent_id=agent_id, message_manager=server.message_manager, @@ -644,17 +646,39 @@ async def send_message_streaming( It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True. 
""" actor = server.user_manager.get_user_or_default(user_id=actor_id) - result = await server.send_message_to_agent( - agent_id=agent_id, - actor=actor, - messages=request.messages, - stream_steps=True, - stream_tokens=request.stream_tokens, - # Support for AssistantMessage - use_assistant_message=request.use_assistant_message, - assistant_message_tool_name=request.assistant_message_tool_name, - assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, - ) + # TODO: This is redundant, remove soon + agent = server.agent_manager.get_agent_by_id(agent_id, actor) + + if agent.llm_config.model_endpoint_type == "anthropic": + logger.warning("USING EXPERIMENTAL!") + experimental_agent = LettaAgent( + agent_id=agent_id, + message_manager=server.message_manager, + agent_manager=server.agent_manager, + block_manager=server.block_manager, + passage_manager=server.passage_manager, + actor=actor, + ) + + messages = request.messages + content = messages[0].content[0].text if messages and not isinstance(messages[0].content, str) else messages[0].content + result = StreamingResponse( + experimental_agent.step_stream(UserMessage(content=content), max_steps=10, use_assistant_message=request.use_assistant_message), + media_type="text/event-stream", + ) + else: + result = await server.send_message_to_agent( + agent_id=agent_id, + actor=actor, + messages=request.messages, + stream_steps=True, + stream_tokens=request.stream_tokens, + # Support for AssistantMessage + use_assistant_message=request.use_assistant_message, + assistant_message_tool_name=request.assistant_message_tool_name, + assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, + ) + return result @@ -670,31 +694,17 @@ async def process_message_background( ) -> None: """Background task to process the message and update job status.""" try: - # TODO(matt) we should probably make this stream_steps and log each step as it progresses, so the job update GET can see the total steps so far + partial usage? 
- if settings.use_experimental: - logger.warning("USING EXPERIMENTAL!") - experimental_agent = LettaAgent( - agent_id=agent_id, - message_manager=server.message_manager, - agent_manager=server.agent_manager, - block_manager=server.block_manager, - passage_manager=server.passage_manager, - actor=actor, - ) - content = messages[0].content[0].text if messages and not isinstance(messages[0].content, str) else messages[0].content - result = await experimental_agent.step(UserMessage(content=content), max_steps=10) - else: - result = await server.send_message_to_agent( - agent_id=agent_id, - actor=actor, - messages=messages, - stream_steps=False, # NOTE(matt) - stream_tokens=False, - use_assistant_message=use_assistant_message, - assistant_message_tool_name=assistant_message_tool_name, - assistant_message_tool_kwarg=assistant_message_tool_kwarg, - metadata={"job_id": job_id}, # Pass job_id through metadata - ) + result = await server.send_message_to_agent( + agent_id=agent_id, + actor=actor, + messages=messages, + stream_steps=False, # NOTE(matt) + stream_tokens=False, + use_assistant_message=use_assistant_message, + assistant_message_tool_name=assistant_message_tool_name, + assistant_message_tool_kwarg=assistant_message_tool_kwarg, + metadata={"job_id": job_id}, # Pass job_id through metadata + ) # Update job status to completed job_update = JobUpdate( diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index 72ec07eb..a8aea959 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -18,7 +18,7 @@ from letta.errors import ContextWindowExceededError, RateLimitExceededError from letta.helpers.datetime_helpers import get_utc_time from letta.log import get_logger from letta.schemas.enums import MessageRole -from letta.schemas.letta_message_content import TextContent +from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.message import Message from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User @@ -167,7 +167,7 @@ def create_user_message(input_message: dict, agent_id: str, actor: User) -> Mess return user_message -def create_tool_call_messages_from_openai_response( +def create_letta_messages_from_llm_response( agent_id: str, model: str, function_name: str, @@ -177,6 +177,9 @@ def create_tool_call_messages_from_openai_response( function_response: Optional[str], actor: User, add_heartbeat_request_system_message: bool = False, + reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None, + pre_computed_assistant_message_id: Optional[str] = None, + pre_computed_tool_message_id: Optional[str] = None, ) -> List[Message]: messages = [] @@ -190,9 +193,11 @@ def create_tool_call_messages_from_openai_response( ), type="function", ) + # TODO: Use ToolCallContent instead of tool_calls + # TODO: This helps preserve ordering assistant_message = Message( role=MessageRole.assistant, - content=[], + content=reasoning_content if reasoning_content else [], organization_id=actor.organization_id, agent_id=agent_id, model=model, @@ -200,8 +205,12 @@ def create_tool_call_messages_from_openai_response( tool_call_id=tool_call_id, created_at=get_utc_time(), ) + if pre_computed_assistant_message_id: + assistant_message.id = pre_computed_assistant_message_id messages.append(assistant_message) + # TODO: Use ToolReturnContent instead of TextContent + # TODO: This helps preserve 
ordering tool_message = Message( role=MessageRole.tool, content=[TextContent(text=package_function_response(function_call_success, function_response))], @@ -212,6 +221,8 @@ def create_tool_call_messages_from_openai_response( tool_call_id=tool_call_id, created_at=get_utc_time(), ) + if pre_computed_tool_message_id: + tool_message.id = pre_computed_tool_message_id messages.append(tool_message) if add_heartbeat_request_system_message: @@ -243,7 +254,7 @@ def create_assistant_messages_from_openai_response( """ tool_call_id = str(uuid.uuid4()) - return create_tool_call_messages_from_openai_response( + return create_letta_messages_from_llm_response( agent_id=agent_id, model=model, function_name=DEFAULT_MESSAGE_TOOL, diff --git a/letta/services/tool_executor/tool_executor.py b/letta/services/tool_executor/tool_executor.py index d4b8f444..6ed7a1ce 100644 --- a/letta/services/tool_executor/tool_executor.py +++ b/letta/services/tool_executor/tool_executor.py @@ -41,6 +41,7 @@ class LettaCoreToolExecutor(ToolExecutor): "send_message": self.send_message, "conversation_search": self.conversation_search, "archival_memory_search": self.archival_memory_search, + "archival_memory_insert": self.archival_memory_insert, } if function_name not in function_map: diff --git a/tests/integration_test_experimental.py b/tests/integration_test_experimental.py index d1d4e486..cd11676e 100644 --- a/tests/integration_test_experimental.py +++ b/tests/integration_test_experimental.py @@ -1,4 +1,3 @@ -import concurrent import os import threading import time @@ -8,7 +7,7 @@ import httpx import openai import pytest from dotenv import load_dotenv -from letta_client import Letta +from letta_client import CreateBlock, Letta from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from letta.agents.letta_agent import LettaAgent @@ -425,23 +424,44 @@ def run_supervisor_worker_group(client: Letta, weather_tool, group_id: str): return response -import concurrent.futures +def test_anthropic_streaming(client: Letta): + agent_name = "anthropic_tester" + existing_agents = client.agents.list(tags=[agent_name]) + for worker in existing_agents: + client.agents.delete(agent_id=worker.id) -def test_multi_agent_broadcast_parallel(client: Letta, weather_tool): - start_time = time.time() - num_groups = 5 + llm_config = LLMConfig( + model="claude-3-7-sonnet-20250219", + model_endpoint_type="anthropic", + model_endpoint="https://api.anthropic.com/v1", + context_window=32000, + handle=f"anthropic/claude-3-5-sonnet-20241022", + put_inner_thoughts_in_kwargs=False, + max_tokens=4096, + enable_reasoner=True, + max_reasoning_tokens=1024, + ) - with concurrent.futures.ThreadPoolExecutor(max_workers=num_groups) as executor: - futures = [] - for i in range(num_groups): - group_id = str(uuid.uuid4())[:8] - futures.append(executor.submit(run_supervisor_worker_group, client, weather_tool, group_id)) + agent = client.agents.create( + name=agent_name, + tags=[agent_name], + include_base_tools=True, + embedding="letta/letta-free", + llm_config=llm_config, + memory_blocks=[CreateBlock(label="human", value="")], + # tool_rules=[InitToolRule(tool_name="core_memory_append")] + ) - results = [f.result() for f in futures] + response = client.agents.messages.create_stream( + agent_id=agent.id, + messages=[ + { + "role": "user", + "content": "Use core memory append to append `banana` to the persona core memory.", + } + ], + stream_tokens=True, + ) - # Optionally: assert something or log runtimes - print(f"Executed {num_groups} supervisor-worker groups in 
parallel.") - print(f"Total runtime: {time.time() - start_time:.2f} seconds") - for idx, r in enumerate(results): - assert r is not None, f"Group {idx} returned no response" + print(list(response))
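
Notes on the mechanics above (illustrative sketches, not part of the patch).

The new tool_choice selection in AnthropicClient.build_request_data packs a small decision table into branchy code. Restated below as a pure function for reference; the names are illustrative, and the real method additionally filters tools_for_request and, in the forced-call branch, mutates llm_config.put_inner_thoughts_in_kwargs as a side effect:

```python
from typing import Optional


def pick_tool_choice(
    tools: Optional[list],
    force_tool_call: Optional[str],
    put_inner_thoughts_in_kwargs: bool,
) -> Optional[dict]:
    """Distillation of the tool_choice table in build_request_data."""
    if not tools:
        # Summarization path: no tools are offered, so no tool_choice either.
        return None
    if force_tool_call is not None:
        # A specific tool is forced; the patch also forces inner thoughts into
        # kwargs here, since a thinking block cannot precede a forced call.
        return {"type": "tool", "name": force_tool_call}
    if put_inner_thoughts_in_kwargs:
        # Inner thoughts ride inside the tool call arguments, so some tool
        # call must always be emitted.
        return {"type": "any", "disable_parallel_tool_use": True}
    # Otherwise let the model decide, e.g. emit extended thinking first.
    return {"type": "auto", "disable_parallel_tool_use": True}
```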
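
The model_dump_json override added to MessageStreamStatus looks odd in isolation. step_stream formats every event as f"data: {chunk.model_dump_json()}\n\n", and giving the enum the same method lets the terminator reuse that code path while emitting the conventional unquoted frame. A minimal demonstration:

```python
import json
from enum import Enum


class MessageStreamStatus(str, Enum):
    done = "[DONE]"

    def model_dump_json(self) -> str:
        # Mirrors the patch: return the bare string, not a JSON-encoded one.
        return "[DONE]"


# JSON-encoding the value would quote it, and SSE clients scanning for the
# literal "data: [DONE]" terminator would never match it:
assert json.dumps(MessageStreamStatus.done) == '"[DONE]"'
assert MessageStreamStatus.done.model_dump_json() == "[DONE]"
print(f"data: {MessageStreamStatus.done.model_dump_json()}")  # data: [DONE]
```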
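
On the consumer side, the stream is one "data: <json>" line per LettaMessage, terminated by "data: [DONE]". A sketch of a client loop follows; only the framing is taken from this diff, while the route path, payload shape, and the message_type field are assumptions for illustration:

```python
import json

import httpx


def consume_stream(base_url: str, agent_id: str, text: str) -> None:
    url = f"{base_url}/v1/agents/{agent_id}/messages/stream"  # assumed path
    payload = {"messages": [{"role": "user", "content": text}], "stream_tokens": True}
    with httpx.stream("POST", url, json=payload, timeout=60.0) as response:
        for line in response.iter_lines():
            if not line.startswith("data: "):
                continue  # skip blank separators between events
            data = line[len("data: "):]
            if data == "[DONE]":  # MessageStreamStatus terminator
                break
            event = json.loads(data)
            # e.g. reasoning_message, tool_call_message, assistant_message
            print(event.get("message_type"), event)
```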
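
The subtlest behavior in AnthropicStreamingInterface is holding ToolCallMessage chunks back while inner thoughts are still streaming inside the tool call kwargs: _check_inner_thoughts_complete declares them complete once a second key appears after inner_thoughts in the partially parsed arguments, and the buffer is then flushed in arrival order. A toy simulation of that rule; the real code uses OptimisticJSONParser on truncated JSON, so the partial payloads below are chosen to stay parseable with plain json.loads:

```python
import json

INNER_THOUGHTS_KWARG = "inner_thoughts"  # assumed value, for the sketch only


def inner_thoughts_complete(combined_args: str) -> bool:
    # Stand-in for OptimisticJSONParser.parse on partial JSON.
    parsed = json.loads(combined_args)
    # Complete once a key other than inner_thoughts has started arriving.
    return len(parsed) > 1 and INNER_THOUGHTS_KWARG in parsed


buffer: list = []
for partial_args, tool_delta in [
    ('{"inner_thoughts": "I should greet"}', "delta-1"),
    ('{"inner_thoughts": "I should greet", "message": "hi"}', "delta-2"),
]:
    if inner_thoughts_complete(partial_args):
        for held in buffer:  # flush buffered deltas in arrival order
            print("yield", held)
        buffer.clear()
        print("yield", tool_delta)  # later chunks stream straight through
    else:
        buffer.append(tool_delta)  # inner thoughts still streaming: hold back
```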