feat: Enable Anthropic streaming on new agent loop (#1550)

Matthew Zhou
2025-04-03 19:40:48 -07:00
committed by GitHub
parent b57202f18b
commit 3ba79db859
14 changed files with 652 additions and 129 deletions


@@ -295,6 +295,7 @@ class Agent(BaseAgent):
and not self.supports_structured_output
and len(self.tool_rules_solver.init_tool_rules) > 0
):
# TODO: This just seems wrong? What if there is more than one init tool rule?
force_tool_call = self.tool_rules_solver.init_tool_rules[0].tool_name
# Force a tool call if exactly one tool is specified
elif step_count is not None and step_count > 0 and len(allowed_tool_names) == 1:
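The branch above forces whichever tool the first init rule names; a minimal standalone sketch of that selection logic (simplified stand-in types, not the real Letta classes):

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class InitToolRule:
    tool_name: str

@dataclass
class ToolRulesSolver:
    init_tool_rules: List[InitToolRule] = field(default_factory=list)

def pick_forced_tool(
    solver: ToolRulesSolver,
    supports_structured_output: bool,
    step_count: Optional[int],
    allowed_tool_names: List[str],
) -> Optional[str]:
    if not supports_structured_output and len(solver.init_tool_rules) > 0:
        # Mirrors the TODO above: only the first init rule wins.
        return solver.init_tool_rules[0].tool_name
    if step_count is not None and step_count > 0 and len(allowed_tool_names) == 1:
        # Force a tool call if exactly one tool is allowed.
        return allowed_tool_names[0]
    return None

assert pick_forced_tool(ToolRulesSolver([InitToolRule("send_message")]), False, 0, []) == "send_message"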


@@ -1,9 +1,10 @@
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Optional
from typing import Any, AsyncGenerator, Optional, Union
import openai
from letta.schemas.letta_message import UserMessage
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, UserMessage
from letta.schemas.letta_response import LettaResponse
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
@@ -39,9 +40,11 @@ class BaseAgent(ABC):
raise NotImplementedError
@abstractmethod
async def step_stream(self, input_message: UserMessage, max_steps: int = 10) -> AsyncGenerator[str, None]:
async def step_stream(
self, input_message: UserMessage, max_steps: int = 10
) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
"""
Main async execution loop for the agent. Implementations must yield messages as SSE events.
Main streaming execution loop for the agent.
"""
raise NotImplementedError
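With the widened return type, implementations yield typed schema objects rather than pre-framed SSE strings; a hedged sketch of a conforming subclass (other abstract methods elided, so this class is illustrative only):

from typing import AsyncGenerator, Union

from letta.agents.base_agent import BaseAgent
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, UserMessage

class NoopStreamingAgent(BaseAgent):
    async def step_stream(
        self, input_message: UserMessage, max_steps: int = 10
    ) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
        # Yield schema objects; the HTTP layer is responsible for SSE framing.
        yield MessageStreamStatus.done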

letta/agents/helpers.py (new file, +52)

@@ -0,0 +1,52 @@
from typing import Dict, List, Tuple
from letta.schemas.agent import AgentState
from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import Message
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.utils import create_user_message
from letta.services.message_manager import MessageManager
def _create_letta_response(new_in_context_messages: list[Message], use_assistant_message: bool) -> LettaResponse:
"""
Converts the newly created/persisted messages into a LettaResponse.
"""
response_messages = []
for msg in new_in_context_messages:
response_messages.extend(msg.to_letta_message(use_assistant_message=use_assistant_message))
return LettaResponse(messages=response_messages, usage=LettaUsageStatistics())
def _prepare_in_context_messages(
input_message: Dict, agent_state: AgentState, message_manager: MessageManager, actor: User
) -> Tuple[List[Message], List[Message]]:
"""
Prepares in-context messages for an agent, based on the current state and a new user input.
Args:
input_message (Dict): The new user input message to process.
agent_state (AgentState): The current state of the agent, including message buffer config.
message_manager (MessageManager): The manager used to retrieve and create messages.
actor (User): The user performing the action, used for access control and attribution.
Returns:
Tuple[List[Message], List[Message]]: A tuple containing:
- The current in-context messages (existing context for the agent).
- The new in-context messages (messages created from the new input).
"""
if agent_state.message_buffer_autoclear:
# If autoclear is enabled, only include the most recent system message (usually at index 0)
current_in_context_messages = [message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=actor)[0]]
else:
# Otherwise, include the full list of messages by ID for context
current_in_context_messages = message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=actor)
# Create a new user message from the input and store it
new_in_context_messages = message_manager.create_many_messages(
[create_user_message(input_message=input_message, agent_id=agent_state.id, actor=actor)], actor=actor
)
return current_in_context_messages, new_in_context_messages
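Together these two helpers bracket the agent loop: prepare context going in, build the response coming out. A hedged usage sketch (managers and state supplied by the caller):

from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages

def run_agent_turn(agent_state, message_manager, actor, input_message: dict):
    # Split context into what already existed vs. what this turn created.
    current_msgs, new_msgs = _prepare_in_context_messages(
        input_message, agent_state, message_manager, actor
    )
    # ... LLM loop goes here, extending new_msgs with persisted messages ...
    return _create_letta_response(new_in_context_messages=new_msgs,
                                  use_assistant_message=True)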


@@ -1,27 +1,32 @@
import asyncio
import json
import uuid
from typing import Any, AsyncGenerator, Dict, List, Tuple
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
from openai import AsyncStream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from letta.agents.base_agent import BaseAgent
from letta.constants import DEFAULT_MESSAGE_TOOL
from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.tool_execution_helper import enable_strict_mode
from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface
from letta.llm_api.llm_client import LLMClient
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.log import get_logger
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import AssistantMessage
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import Message, MessageUpdate
from letta.schemas.openai.chat_completion_request import UserMessage
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.openai.chat_completion_response import ToolCall
from letta.schemas.user import User
from letta.server.rest_api.utils import create_tool_call_messages_from_openai_response, create_user_message
from letta.server.rest_api.utils import create_letta_messages_from_llm_response
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import compile_system_message
@@ -58,76 +63,130 @@ class LettaAgent(BaseAgent):
async def step(self, input_message: UserMessage, max_steps: int = 10) -> LettaResponse:
input_message = self.pre_process_input_message(input_message)
agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
# TODO: Extend to beyond just system message
system_message = [self.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=self.actor)[0]]
persisted_letta_messages = self.message_manager.create_many_messages(
[create_user_message(input_message=input_message, agent_id=agent_state.id, actor=self.actor)], actor=self.actor
current_in_context_messages, new_in_context_messages = _prepare_in_context_messages(
input_message, agent_state, self.message_manager, self.actor
)
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
# TODO: Note that we do absolutely 0 pulling in of in-context messages here
# TODO: This is specific to B, and needs to be changed
llm_client = LLMClient.create(
llm_config=agent_state.llm_config,
put_inner_thoughts_first=True,
)
for step in range(max_steps):
response = await self._get_ai_reply(
in_context_messages=system_message + persisted_letta_messages,
llm_client=llm_client,
in_context_messages=current_in_context_messages + new_in_context_messages,
agent_state=agent_state,
tool_rules_solver=tool_rules_solver,
stream=False,
)
persisted_messages, should_continue = await self._handle_ai_response(response, agent_state, tool_rules_solver)
persisted_letta_messages.extend(persisted_messages)
tool_call = response.choices[0].message.tool_calls[0]
persisted_messages, should_continue = await self._handle_ai_response(tool_call, agent_state, tool_rules_solver)
new_in_context_messages.extend(persisted_messages)
if not should_continue:
break
# Persist messages
# Translate to letta response messages
response_messages = []
for message in persisted_letta_messages:
response_messages += message.to_letta_message(use_assistant_message=self.use_assistant_message)
# Extend the in context message ids
if not agent_state.message_buffer_autoclear:
message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
return LettaResponse(
messages=response_messages,
# TODO: Actually populate this
usage=LettaUsageStatistics(),
)
return _create_letta_response(new_in_context_messages=new_in_context_messages, use_assistant_message=self.use_assistant_message)
async def step_stream(self, input_message: UserMessage, max_steps: int = 10) -> AsyncGenerator[str, None]:
@trace_method
async def step_stream(
self, input_message: UserMessage, max_steps: int = 10, use_assistant_message: bool = False
) -> AsyncGenerator[str, None]:
"""
Main streaming loop that yields partial tokens as SSE events.
After each stream completes, the finalized tool call is executed and persisted via _handle_ai_response.
"""
raise NotImplementedError("Not implemented for letta agent")
input_message = self.pre_process_input_message(input_message)
agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
current_in_context_messages, new_in_context_messages = _prepare_in_context_messages(
input_message, agent_state, self.message_manager, self.actor
)
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
llm_client = LLMClient.create(
llm_config=agent_state.llm_config,
put_inner_thoughts_first=True,
)
for step in range(max_steps):
stream = await self._get_ai_reply(
llm_client=llm_client,
in_context_messages=current_in_context_messages + new_in_context_messages,
agent_state=agent_state,
tool_rules_solver=tool_rules_solver,
stream=True,
)
# TODO: THIS IS INCREDIBLY UGLY
# TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
interface = AnthropicStreamingInterface(
use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=llm_client.llm_config.put_inner_thoughts_in_kwargs
)
async for chunk in interface.process(stream):
yield f"data: {chunk.model_dump_json()}\n\n"
# Process resulting stream content
tool_call = interface.get_tool_call_object()
reasoning_content = interface.get_reasoning_content()
persisted_messages, should_continue = await self._handle_ai_response(
tool_call,
agent_state,
tool_rules_solver,
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=interface.letta_assistant_message_id,
pre_computed_tool_message_id=interface.letta_tool_message_id,
)
new_in_context_messages.extend(persisted_messages)
if not should_continue:
break
# Extend the in context message ids
if not agent_state.message_buffer_autoclear:
message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
# TODO: Also yield out a letta usage stats SSE
yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n"
@trace_method
async def _get_ai_reply(
self,
llm_client: LLMClientBase,
in_context_messages: List[Message],
agent_state: AgentState,
tool_rules_solver: ToolRulesSolver,
stream: bool,
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
tools = [
t
for t in agent_state.tools
if t.tool_type in {ToolType.CUSTOM}
or (t.tool_type == ToolType.LETTA_CORE and t.name == DEFAULT_MESSAGE_TOOL)
if t.tool_type in {ToolType.CUSTOM, ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE}
or (t.tool_type == ToolType.LETTA_MULTI_AGENT_CORE and t.name == "send_message_to_agents_matching_tags")
]
valid_tool_names = set(tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools])))
allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in valid_tool_names]
valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools]))
# TODO: Copied from legacy agent loop, so please be cautious
# Set force tool
force_tool_call = None
if len(valid_tool_names) == 1:
force_tool_call = valid_tool_names[0]
llm_client = LLMClient.create(
llm_config=agent_state.llm_config,
put_inner_thoughts_first=True,
)
allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
response = await llm_client.send_llm_request_async(
messages=in_context_messages,
tools=allowed_tools,
tool_call=None,
stream=False,
force_tool_call=force_tool_call,
stream=stream,
)
return response
@@ -135,19 +194,18 @@ class LettaAgent(BaseAgent):
@trace_method
async def _handle_ai_response(
self,
chat_completion_response: ChatCompletion,
tool_call: ToolCall,
agent_state: AgentState,
tool_rules_solver: ToolRulesSolver,
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
pre_computed_assistant_message_id: Optional[str] = None,
pre_computed_tool_message_id: Optional[str] = None,
) -> Tuple[List[Message], bool]:
"""
Handle the final AI response once streaming is done: execute the requested
tool call, persist the resulting messages, and report whether the loop
should continue stepping.
"""
# TODO: Some key assumptions here.
# TODO: Assume every call has a tool call, i.e. tool_choice is REQUIRED
tool_call = chat_completion_response.choices[0].message.tool_calls[0]
tool_call_name = tool_call.function.name
tool_call_args_str = tool_call.function.arguments
@@ -158,6 +216,8 @@ class LettaAgent(BaseAgent):
# Get request heartbeats and coerce to bool
request_heartbeat = tool_args.pop("request_heartbeat", False)
# Pre-emptively pop out inner_thoughts
tool_args.pop(INNER_THOUGHTS_KWARG, "")
# So this is necessary, because sometimes non-structured outputs makes mistakes
if not isinstance(request_heartbeat, bool):
@@ -186,7 +246,7 @@ class LettaAgent(BaseAgent):
continue_stepping = True
# 5. Persist to DB
tool_call_messages = create_tool_call_messages_from_openai_response(
tool_call_messages = create_letta_messages_from_llm_response(
agent_id=agent_state.id,
model=agent_state.llm_config.model,
function_name=tool_call_name,
@@ -196,6 +256,9 @@ class LettaAgent(BaseAgent):
function_response=tool_result,
actor=self.actor,
add_heartbeat_request_system_message=continue_stepping,
reasoning_content=reasoning_content,
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
pre_computed_tool_message_id=pre_computed_tool_message_id,
)
persisted_messages = self.message_manager.create_many_messages(tool_call_messages, actor=self.actor)


@@ -34,7 +34,7 @@ from letta.schemas.user import User
from letta.server.rest_api.utils import (
convert_letta_messages_to_openai,
create_assistant_messages_from_openai_response,
create_tool_call_messages_from_openai_response,
create_letta_messages_from_llm_response,
create_user_message,
)
from letta.services.agent_manager import AgentManager
@@ -207,7 +207,7 @@ class VoiceAgent(BaseAgent):
in_memory_message_history.append(heartbeat_user_message.model_dump())
# 5. Also store in DB
tool_call_messages = create_tool_call_messages_from_openai_response(
tool_call_messages = create_letta_messages_from_llm_response(
agent_id=agent_state.id,
model=agent_state.llm_config.model,
function_name=tool_call_name,


@@ -0,0 +1,323 @@
from datetime import datetime, timezone
from enum import Enum
from typing import AsyncGenerator, List, Union
from anthropic import AsyncStream
from anthropic.types.beta import (
BetaInputJSONDelta,
BetaRawContentBlockDeltaEvent,
BetaRawContentBlockStartEvent,
BetaRawContentBlockStopEvent,
BetaRawMessageDeltaEvent,
BetaRawMessageStartEvent,
BetaRawMessageStopEvent,
BetaRawMessageStreamEvent,
BetaRedactedThinkingBlock,
BetaSignatureDelta,
BetaTextBlock,
BetaTextDelta,
BetaThinkingBlock,
BetaThinkingDelta,
BetaToolUseBlock,
)
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.log import get_logger
from letta.schemas.letta_message import (
AssistantMessage,
HiddenReasoningMessage,
LettaMessage,
ReasoningMessage,
ToolCallDelta,
ToolCallMessage,
)
from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
logger = get_logger(__name__)
# TODO: These modes aren't used right now - but they can be useful if we do multiple sequential tool calls within one Claude message
class EventMode(Enum):
TEXT = "TEXT"
TOOL_USE = "TOOL_USE"
THINKING = "THINKING"
REDACTED_THINKING = "REDACTED_THINKING"
class AnthropicStreamingInterface:
"""
Encapsulates the logic for streaming responses from Anthropic.
This class handles parsing of partial tokens, pre-execution messages,
and detection of tool call events.
"""
def __init__(self, use_assistant_message: bool = False, put_inner_thoughts_in_kwarg: bool = False):
self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
self.use_assistant_message = use_assistant_message
# Premake IDs for database writes
self.letta_assistant_message_id = Message.generate_id()
self.letta_tool_message_id = Message.generate_id()
self.anthropic_mode = None
self.message_id = None
self.accumulated_inner_thoughts = []
self.tool_call_id = None
self.tool_call_name = None
self.accumulated_tool_call_args = []
self.previous_parse = {}
# usage trackers
self.input_tokens = 0
self.output_tokens = 0
# reasoning object trackers
self.reasoning_messages = []
# Buffer to hold tool call messages until inner thoughts are complete
self.tool_call_buffer = []
self.inner_thoughts_complete = False
self.put_inner_thoughts_in_kwarg = put_inner_thoughts_in_kwarg
def get_tool_call_object(self) -> ToolCall:
"""Useful for agent loop"""
return ToolCall(
id=self.tool_call_id, function=FunctionCall(arguments="".join(self.accumulated_tool_call_args), name=self.tool_call_name)
)
def _check_inner_thoughts_complete(self, combined_args: str) -> bool:
"""
Check if inner thoughts are complete in the current tool call arguments
by looking for a closing quote after the inner_thoughts field
"""
if not self.put_inner_thoughts_in_kwarg:
# Inner thoughts are not being put in kwargs, so there is nothing to wait for
return True
else:
parsed = self.optimistic_json_parser.parse(combined_args)
# TODO: This will break on tools with 0 input
return len(parsed.keys()) > 1 and INNER_THOUGHTS_KWARG in parsed.keys()
async def process(self, stream: AsyncStream[BetaRawMessageStreamEvent]) -> AsyncGenerator[LettaMessage, None]:
async with stream:
async for event in stream:
# TODO: Support BetaThinkingBlock, BetaRedactedThinkingBlock
if isinstance(event, BetaRawContentBlockStartEvent):
content = event.content_block
if isinstance(content, BetaTextBlock):
self.anthropic_mode = EventMode.TEXT
# TODO: Can capture citations, etc.
elif isinstance(content, BetaToolUseBlock):
self.anthropic_mode = EventMode.TOOL_USE
self.tool_call_id = content.id
self.tool_call_name = content.name
self.inner_thoughts_complete = False
if not self.use_assistant_message:
# Buffer the initial tool call message instead of yielding immediately
tool_call_msg = ToolCallMessage(
id=self.letta_tool_message_id,
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
date=datetime.now(timezone.utc).isoformat(),
)
self.tool_call_buffer.append(tool_call_msg)
elif isinstance(content, BetaThinkingBlock):
self.anthropic_mode = EventMode.THINKING
# TODO: Can capture signature, etc.
elif isinstance(content, BetaRedactedThinkingBlock):
self.anthropic_mode = EventMode.REDACTED_THINKING
hidden_reasoning_message = HiddenReasoningMessage(
id=self.letta_assistant_message_id,
state="redacted",
hidden_reasoning=content.data,
date=datetime.now(timezone.utc).isoformat(),
)
self.reasoning_messages.append(hidden_reasoning_message)
yield hidden_reasoning_message
elif isinstance(event, BetaRawContentBlockDeltaEvent):
delta = event.delta
if isinstance(delta, BetaTextDelta):
# Safety check
if not self.anthropic_mode == EventMode.TEXT:
raise RuntimeError(
f"Streaming integrity failed - received BetaTextDelta object while not in TEXT EventMode: {delta}"
)
# TODO: Strip out </thinking> more robustly, this is pretty hacky lol
delta.text = delta.text.replace("</thinking>", "")
self.accumulated_inner_thoughts.append(delta.text)
reasoning_message = ReasoningMessage(
id=self.letta_assistant_message_id,
reasoning=self.accumulated_inner_thoughts[-1],
date=datetime.now(timezone.utc).isoformat(),
)
self.reasoning_messages.append(reasoning_message)
yield reasoning_message
elif isinstance(delta, BetaInputJSONDelta):
if not self.anthropic_mode == EventMode.TOOL_USE:
raise RuntimeError(
f"Streaming integrity failed - received BetaInputJSONDelta object while not in TOOL_USE EventMode: {delta}"
)
self.accumulated_tool_call_args.append(delta.partial_json)
combined_args = "".join(self.accumulated_tool_call_args)
current_parsed = self.optimistic_json_parser.parse(combined_args)
# Start detecting a difference in inner thoughts
previous_inner_thoughts = self.previous_parse.get(INNER_THOUGHTS_KWARG, "")
current_inner_thoughts = current_parsed.get(INNER_THOUGHTS_KWARG, "")
inner_thoughts_diff = current_inner_thoughts[len(previous_inner_thoughts) :]
if inner_thoughts_diff:
reasoning_message = ReasoningMessage(
id=self.letta_assistant_message_id,
reasoning=inner_thoughts_diff,
date=datetime.now(timezone.utc).isoformat(),
)
self.reasoning_messages.append(reasoning_message)
yield reasoning_message
# Check if inner thoughts are complete - if so, flush the buffer
if not self.inner_thoughts_complete and self._check_inner_thoughts_complete(combined_args):
self.inner_thoughts_complete = True
# Flush all buffered tool call messages
for buffered_msg in self.tool_call_buffer:
yield buffered_msg
self.tool_call_buffer = []
# Start detecting special case of "send_message"
if self.tool_call_name == DEFAULT_MESSAGE_TOOL and self.use_assistant_message:
previous_send_message = self.previous_parse.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
current_send_message = current_parsed.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
send_message_diff = current_send_message[len(previous_send_message) :]
# Only stream out if it's not an empty string
if send_message_diff:
yield AssistantMessage(
id=self.letta_assistant_message_id,
content=[TextContent(text=send_message_diff)],
date=datetime.now(timezone.utc).isoformat(),
)
else:
# Otherwise, it is a normal tool call - buffer or yield based on inner thoughts status
tool_call_msg = ToolCallMessage(
id=self.letta_tool_message_id,
tool_call=ToolCallDelta(arguments=delta.partial_json),
date=datetime.now(timezone.utc).isoformat(),
)
if self.inner_thoughts_complete:
yield tool_call_msg
else:
self.tool_call_buffer.append(tool_call_msg)
# Set previous parse
self.previous_parse = current_parsed
elif isinstance(delta, BetaThinkingDelta):
# Safety check
if not self.anthropic_mode == EventMode.THINKING:
raise RuntimeError(
f"Streaming integrity failed - received BetaThinkingBlock object while not in THINKING EventMode: {delta}"
)
reasoning_message = ReasoningMessage(
id=self.letta_assistant_message_id,
source="reasoner_model",
reasoning=delta.thinking,
date=datetime.now(timezone.utc).isoformat(),
)
self.reasoning_messages.append(reasoning_message)
yield reasoning_message
elif isinstance(delta, BetaSignatureDelta):
# Safety check
if not self.anthropic_mode == EventMode.THINKING:
raise RuntimeError(
f"Streaming integrity failed - received BetaSignatureDelta object while not in THINKING EventMode: {delta}"
)
reasoning_message = ReasoningMessage(
id=self.letta_assistant_message_id,
source="reasoner_model",
reasoning="",
date=datetime.now(timezone.utc).isoformat(),
signature=delta.signature,
)
self.reasoning_messages.append(reasoning_message)
yield reasoning_message
elif isinstance(event, BetaRawMessageStartEvent):
self.message_id = event.message.id
self.input_tokens += event.message.usage.input_tokens
self.output_tokens += event.message.usage.output_tokens
elif isinstance(event, BetaRawMessageDeltaEvent):
self.output_tokens += event.usage.output_tokens
elif isinstance(event, BetaRawMessageStopEvent):
# Don't do anything here! We don't want to stop the stream.
pass
elif isinstance(event, BetaRawContentBlockStopEvent):
# If we're exiting a tool use block and there are still buffered messages,
# we should flush them now
if self.anthropic_mode == EventMode.TOOL_USE and self.tool_call_buffer:
for buffered_msg in self.tool_call_buffer:
yield buffered_msg
self.tool_call_buffer = []
self.anthropic_mode = None
def get_reasoning_content(self) -> List[Union[TextContent, ReasoningContent, RedactedReasoningContent]]:
def _process_group(
group: List[Union[ReasoningMessage, HiddenReasoningMessage]], group_type: str
) -> Union[TextContent, ReasoningContent, RedactedReasoningContent]:
if group_type == "reasoning":
reasoning_text = "".join(chunk.reasoning for chunk in group)
is_native = any(chunk.source == "reasoner_model" for chunk in group)
signature = next((chunk.signature for chunk in group if chunk.signature is not None), None)
if is_native:
return ReasoningContent(is_native=is_native, reasoning=reasoning_text, signature=signature)
else:
return TextContent(text=reasoning_text)
elif group_type == "redacted":
redacted_text = "".join(chunk.hidden_reasoning for chunk in group if chunk.hidden_reasoning is not None)
return RedactedReasoningContent(data=redacted_text)
else:
raise ValueError("Unexpected group type")
merged = []
current_group = []
current_group_type = None # "reasoning" or "redacted"
for msg in self.reasoning_messages:
# Determine the type of the current message
if isinstance(msg, HiddenReasoningMessage):
msg_type = "redacted"
elif isinstance(msg, ReasoningMessage):
msg_type = "reasoning"
else:
raise ValueError("Unexpected message type")
# Initialize group type if not set
if current_group_type is None:
current_group_type = msg_type
# If the type changes, process the current group
if msg_type != current_group_type:
merged.append(_process_group(current_group, current_group_type))
current_group = []
current_group_type = msg_type
current_group.append(msg)
# Process the final group, if any.
if current_group:
merged.append(_process_group(current_group, current_group_type))
return merged
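get_reasoning_content merges consecutive chunks of the same kind, so a stream of reasoning, reasoning, redacted, reasoning collapses into three content blocks. The same merge, as a tiny standalone demo on (kind, text) tuples rather than the Letta schemas:

from itertools import groupby

chunks = [("reasoning", "Let me "), ("reasoning", "think."),
          ("redacted", "<opaque>"), ("reasoning", "Done.")]

merged = [(kind, "".join(text for _, text in group))
          for kind, group in groupby(chunks, key=lambda c: c[0])]

assert merged == [("reasoning", "Let me think."),
                  ("redacted", "<opaque>"),
                  ("reasoning", "Done.")]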

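The tool-call buffering in process hinges on _check_inner_thoughts_complete: tool call deltas are held back until the streamed arguments contain a second top-level key, which can only appear once inner_thoughts is closed. An illustration of that heuristic with a plain json-based stand-in for OptimisticJSONParser:

import json

INNER_THOUGHTS_KWARG = "inner_thoughts"

def lenient_parse(partial: str) -> dict:
    # Crude stand-in for OptimisticJSONParser: try closing the JSON and reparse.
    for suffix in ("", '"}', "}"):
        try:
            return json.loads(partial + suffix)
        except json.JSONDecodeError:
            continue
    return {}

def inner_thoughts_complete(partial_args: str) -> bool:
    parsed = lenient_parse(partial_args)
    return len(parsed.keys()) > 1 and INNER_THOUGHTS_KWARG in parsed

# Mid-stream inside the inner_thoughts string: keep buffering.
assert not inner_thoughts_complete('{"inner_thoughts": "I should say hel')
# A second key has started, so inner thoughts must be complete: flush the buffer.
assert inner_thoughts_complete('{"inner_thoughts": "hi", "message": "')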

@@ -3,7 +3,9 @@ import re
from typing import List, Optional, Union
import anthropic
from anthropic import AsyncStream
from anthropic.types import Message as AnthropicMessage
from anthropic.types.beta import BetaRawMessageStreamEvent
from letta.errors import (
ContextWindowExceededError,
@@ -28,6 +30,7 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.openai.chat_completion_response import Message as ChoiceMessage
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
from letta.services.provider_manager import ProviderManager
from letta.tracing import trace_method
DUMMY_FIRST_USER_MESSAGE = "User initializing bootup sequence."
@@ -46,18 +49,28 @@ class AnthropicClient(LLMClientBase):
response = await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
return response.model_dump()
@trace_method
async def stream_async(self, request_data: dict) -> AsyncStream[BetaRawMessageStreamEvent]:
client = self._get_anthropic_client(async_client=True)
request_data["stream"] = True
return await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
@trace_method
def _get_anthropic_client(self, async_client: bool = False) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
override_key = ProviderManager().get_anthropic_override_key()
if async_client:
return anthropic.AsyncAnthropic(api_key=override_key) if override_key else anthropic.AsyncAnthropic()
return anthropic.Anthropic(api_key=override_key) if override_key else anthropic.Anthropic()
@trace_method
def build_request_data(
self,
messages: List[PydanticMessage],
tools: List[dict],
force_tool_call: Optional[str] = None,
) -> dict:
# TODO: This needs to get cleaned up. The logic here is pretty confusing.
# TODO: I really want to get rid of prefixing, it's a recipe for disaster code maintenance wise
prefix_fill = True
if not self.use_tool_naming:
raise NotImplementedError("Only tool calling supported on Anthropic API requests")
@@ -73,11 +86,6 @@ class AnthropicClient(LLMClientBase):
# Extended Thinking
if self.llm_config.enable_reasoner:
assert (
self.llm_config.max_reasoning_tokens is not None and self.llm_config.max_reasoning_tokens < self.llm_config.max_tokens
), "max tokens must be greater than thinking budget"
assert not self.llm_config.put_inner_thoughts_in_kwargs, "extended thinking not compatible with put_inner_thoughts_in_kwargs"
data["thinking"] = {
"type": "enabled",
"budget_tokens": self.llm_config.max_reasoning_tokens,
@@ -89,15 +97,35 @@ class AnthropicClient(LLMClientBase):
prefix_fill = False
# Tools
tools_for_request = (
[Tool(function=f) for f in tools if f["name"] == force_tool_call]
if force_tool_call is not None
else [Tool(function=f) for f in tools]
)
if force_tool_call is not None:
self.llm_config.put_inner_thoughts_in_kwargs = True # why do we do this ?
# For an overview on tool choice:
# https://docs.anthropic.com/en/docs/build-with-claude/tool-use/overview
if not tools:
# Special case for summarization path
tools_for_request = None
tool_choice = None
elif force_tool_call is not None:
tool_choice = {"type": "tool", "name": force_tool_call}
tools_for_request = [Tool(function=f) for f in tools if f["name"] == force_tool_call]
# need to have this setting to be able to put inner thoughts in kwargs
if not self.llm_config.put_inner_thoughts_in_kwargs:
logger.warning(
f"Force setting put_inner_thoughts_in_kwargs to True for Claude because there is a forced tool call: {force_tool_call}"
)
self.llm_config.put_inner_thoughts_in_kwargs = True
else:
if self.llm_config.put_inner_thoughts_in_kwargs:
# tool_choice_type other than "auto" only plays nice if thinking goes inside the tool calls
tool_choice = {"type": "any", "disable_parallel_tool_use": True}
else:
tool_choice = {"type": "auto", "disable_parallel_tool_use": True}
tools_for_request = [Tool(function=f) for f in tools] if tools is not None else None
# Add tool choice
data["tool_choice"] = tool_choice
# Add inner thoughts kwarg
# TODO: Can probably make this more efficient
if len(tools_for_request) > 0 and self.llm_config.put_inner_thoughts_in_kwargs:
tools_with_inner_thoughts = add_inner_thoughts_to_functions(
functions=[t.function.model_dump() for t in tools_for_request],

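For reference, the three branches above map onto the standard Anthropic Messages API tool_choice shapes (values per the tool-use docs linked in the comment):

# Forced call of a specific tool:
tool_choice_forced = {"type": "tool", "name": "send_message"}
# Inner thoughts ride inside tool args, so some tool call is required:
tool_choice_any = {"type": "any", "disable_parallel_tool_use": True}
# Otherwise the model may answer in plain text:
tool_choice_auto = {"type": "auto", "disable_parallel_tool_use": True}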

@@ -30,10 +30,11 @@ class JobStatus(str, Enum):
class MessageStreamStatus(str, Enum):
# done_generation = "[DONE_GEN]"
# done_step = "[DONE_STEP]"
done = "[DONE]"
def model_dump_json(self):
return "[DONE]"
class ToolRuleType(str, Enum):
"""

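Overriding model_dump_json makes the generic SSE framing in step_stream emit the literal terminator rather than a JSON-quoted string; a self-contained restatement of the trick:

from enum import Enum

class MessageStreamStatus(str, Enum):
    done = "[DONE]"

    def model_dump_json(self):
        return "[DONE]"

# Clients see the OpenAI-style sentinel, not data: "[DONE]" with quotes.
assert f"data: {MessageStreamStatus.done.model_dump_json()}\n\n" == "data: [DONE]\n\n"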

@@ -123,9 +123,9 @@ class ToolCall(BaseModel):
class ToolCallDelta(BaseModel):
name: Optional[str]
arguments: Optional[str]
tool_call_id: Optional[str]
name: Optional[str] = None
arguments: Optional[str] = None
tool_call_id: Optional[str] = None
def model_dump(self, *args, **kwargs):
"""


@@ -1,6 +1,6 @@
from typing import Literal, Optional
from pydantic import BaseModel, ConfigDict, Field, root_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator
class LLMConfig(BaseModel):
@@ -70,7 +70,8 @@ class LLMConfig(BaseModel):
# FIXME hack to silence pydantic protected namespace warning
model_config = ConfigDict(protected_namespaces=())
@root_validator(pre=True)
@model_validator(mode="before")
@classmethod
def set_default_put_inner_thoughts(cls, values):
"""
Dynamically set the default for put_inner_thoughts_in_kwargs based on the model field,
@@ -79,15 +80,24 @@ class LLMConfig(BaseModel):
model = values.get("model")
# Define models where we want put_inner_thoughts_in_kwargs to be False
# For now it is gpt-4
avoid_put_inner_thoughts_in_kwargs = ["gpt-4"]
# Only modify the value if it's None or not provided
if values.get("put_inner_thoughts_in_kwargs") is None:
values["put_inner_thoughts_in_kwargs"] = False if model in avoid_put_inner_thoughts_in_kwargs else True
return values
@model_validator(mode="after")
def validate_reasoning_constraints(self) -> "LLMConfig":
if self.enable_reasoner:
if self.max_reasoning_tokens is None:
raise ValueError("max_reasoning_tokens must be set when enable_reasoner is True")
if self.max_tokens is not None and self.max_reasoning_tokens >= self.max_tokens:
raise ValueError("max_tokens must be greater than max_reasoning_tokens (thinking budget)")
if self.put_inner_thoughts_in_kwargs:
raise ValueError("Extended thinking is not compatible with put_inner_thoughts_in_kwargs")
return self
@classmethod
def default_config(cls, model_name: str):
"""

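A hedged example of what the new after-validator accepts and rejects (constructor fields mirrored from the test at the bottom of this commit; the exact set of required fields may differ). Note that put_inner_thoughts_in_kwargs must be passed explicitly as False here, since the before-validator would otherwise default it to True for non-gpt-4 models and trip the reasoning check:

import pytest

from letta.schemas.llm_config import LLMConfig  # module path assumed

valid = LLMConfig(
    model="claude-3-7-sonnet-20250219",
    model_endpoint_type="anthropic",
    context_window=32000,
    max_tokens=4096,
    enable_reasoner=True,
    max_reasoning_tokens=1024,
    put_inner_thoughts_in_kwargs=False,
)

with pytest.raises(ValueError):
    # Thinking budget must be strictly below max_tokens.
    LLMConfig(
        model="claude-3-7-sonnet-20250219",
        model_endpoint_type="anthropic",
        context_window=32000,
        max_tokens=1024,
        enable_reasoner=True,
        max_reasoning_tokens=4096,
        put_inner_thoughts_in_kwargs=False,
    )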

@@ -8,6 +8,7 @@ from fastapi.responses import JSONResponse
from marshmallow import ValidationError
from pydantic import Field
from sqlalchemy.exc import IntegrityError, OperationalError
from starlette.responses import StreamingResponse
from letta.agents.letta_agent import LettaAgent
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
@@ -30,7 +31,6 @@ from letta.schemas.user import User
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
from letta.settings import settings
# These could be forward refs, but because FastAPI needs them at runtime they must be imported normally
@@ -590,8 +590,10 @@ async def send_message(
This endpoint accepts a message from a user and processes it through the agent.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
if settings.use_experimental:
logger.warning("USING EXPERIMENTAL!")
# TODO: This is redundant, remove soon
agent = server.agent_manager.get_agent_by_id(agent_id, actor)
if agent.llm_config.model_endpoint_type == "anthropic":
experimental_agent = LettaAgent(
agent_id=agent_id,
message_manager=server.message_manager,
@@ -644,17 +646,39 @@ async def send_message_streaming(
It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
result = await server.send_message_to_agent(
agent_id=agent_id,
actor=actor,
messages=request.messages,
stream_steps=True,
stream_tokens=request.stream_tokens,
# Support for AssistantMessage
use_assistant_message=request.use_assistant_message,
assistant_message_tool_name=request.assistant_message_tool_name,
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
)
# TODO: This is redundant, remove soon
agent = server.agent_manager.get_agent_by_id(agent_id, actor)
if agent.llm_config.model_endpoint_type == "anthropic":
logger.warning("USING EXPERIMENTAL!")
experimental_agent = LettaAgent(
agent_id=agent_id,
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
actor=actor,
)
messages = request.messages
content = messages[0].content[0].text if messages and not isinstance(messages[0].content, str) else messages[0].content
result = StreamingResponse(
experimental_agent.step_stream(UserMessage(content=content), max_steps=10, use_assistant_message=request.use_assistant_message),
media_type="text/event-stream",
)
else:
result = await server.send_message_to_agent(
agent_id=agent_id,
actor=actor,
messages=request.messages,
stream_steps=True,
stream_tokens=request.stream_tokens,
# Support for AssistantMessage
use_assistant_message=request.use_assistant_message,
assistant_message_tool_name=request.assistant_message_tool_name,
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
)
return result
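On the wire, the experimental path serves standard SSE frames and ends with [DONE]; a minimal hedged client sketch with httpx (URL path and auth scheme assumed, not taken from this commit):

import httpx

def consume_stream(base_url: str, agent_id: str, token: str) -> None:
    url = f"{base_url}/v1/agents/{agent_id}/messages/stream"  # path assumed
    payload = {"messages": [{"role": "user", "content": "hi"}], "stream_tokens": True}
    headers = {"Authorization": f"Bearer {token}"}
    with httpx.stream("POST", url, json=payload, headers=headers, timeout=60.0) as resp:
        for line in resp.iter_lines():
            if not line.startswith("data: "):
                continue  # skip blank separators and any keep-alives
            data = line[len("data: "):]
            if data == "[DONE]":
                break
            print(data)  # one JSON-encoded LettaMessage per frame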
@@ -670,31 +694,17 @@ async def process_message_background(
) -> None:
"""Background task to process the message and update job status."""
try:
# TODO(matt) we should probably make this stream_steps and log each step as it progresses, so the job update GET can see the total steps so far + partial usage?
if settings.use_experimental:
logger.warning("USING EXPERIMENTAL!")
experimental_agent = LettaAgent(
agent_id=agent_id,
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
actor=actor,
)
content = messages[0].content[0].text if messages and not isinstance(messages[0].content, str) else messages[0].content
result = await experimental_agent.step(UserMessage(content=content), max_steps=10)
else:
result = await server.send_message_to_agent(
agent_id=agent_id,
actor=actor,
messages=messages,
stream_steps=False, # NOTE(matt)
stream_tokens=False,
use_assistant_message=use_assistant_message,
assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
metadata={"job_id": job_id}, # Pass job_id through metadata
)
result = await server.send_message_to_agent(
agent_id=agent_id,
actor=actor,
messages=messages,
stream_steps=False, # NOTE(matt)
stream_tokens=False,
use_assistant_message=use_assistant_message,
assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
metadata={"job_id": job_id}, # Pass job_id through metadata
)
# Update job status to completed
job_update = JobUpdate(


@@ -18,7 +18,7 @@ from letta.errors import ContextWindowExceededError, RateLimitExceededError
from letta.helpers.datetime_helpers import get_utc_time
from letta.log import get_logger
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.message import Message
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
@@ -167,7 +167,7 @@ def create_user_message(input_message: dict, agent_id: str, actor: User) -> Mess
return user_message
def create_tool_call_messages_from_openai_response(
def create_letta_messages_from_llm_response(
agent_id: str,
model: str,
function_name: str,
@@ -177,6 +177,9 @@ def create_tool_call_messages_from_openai_response(
function_response: Optional[str],
actor: User,
add_heartbeat_request_system_message: bool = False,
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
pre_computed_assistant_message_id: Optional[str] = None,
pre_computed_tool_message_id: Optional[str] = None,
) -> List[Message]:
messages = []
@@ -190,9 +193,11 @@ def create_tool_call_messages_from_openai_response(
),
type="function",
)
# TODO: Use ToolCallContent instead of tool_calls
# TODO: This helps preserve ordering
assistant_message = Message(
role=MessageRole.assistant,
content=[],
content=reasoning_content if reasoning_content else [],
organization_id=actor.organization_id,
agent_id=agent_id,
model=model,
@@ -200,8 +205,12 @@ def create_tool_call_messages_from_openai_response(
tool_call_id=tool_call_id,
created_at=get_utc_time(),
)
if pre_computed_assistant_message_id:
assistant_message.id = pre_computed_assistant_message_id
messages.append(assistant_message)
# TODO: Use ToolReturnContent instead of TextContent
# TODO: This helps preserve ordering
tool_message = Message(
role=MessageRole.tool,
content=[TextContent(text=package_function_response(function_call_success, function_response))],
@@ -212,6 +221,8 @@ def create_tool_call_messages_from_openai_response(
tool_call_id=tool_call_id,
created_at=get_utc_time(),
)
if pre_computed_tool_message_id:
tool_message.id = pre_computed_tool_message_id
messages.append(tool_message)
if add_heartbeat_request_system_message:
@@ -243,7 +254,7 @@ def create_assistant_messages_from_openai_response(
"""
tool_call_id = str(uuid.uuid4())
return create_tool_call_messages_from_openai_response(
return create_letta_messages_from_llm_response(
agent_id=agent_id,
model=model,
function_name=DEFAULT_MESSAGE_TOOL,

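The pre_computed_*_id plumbing keeps streamed and persisted message IDs in agreement: the streaming interface mints IDs before any tokens arrive, chunks are yielded carrying those IDs, and this helper stamps the same IDs onto the rows it persists. A toy illustration of the round-trip (generate_id is a stand-in for Message.generate_id):

import uuid

def generate_id() -> str:
    return f"message-{uuid.uuid4()}"  # stand-in for Message.generate_id()

# 1. Minted up front by AnthropicStreamingInterface.__init__:
assistant_message_id = generate_id()
# 2. Every streamed chunk carries it...
streamed_chunk = {"id": assistant_message_id, "message_type": "reasoning_message"}
# 3. ...and the persisted row reuses it, so clients can reconcile the two.
persisted_row = {"id": assistant_message_id, "role": "assistant"}
assert streamed_chunk["id"] == persisted_row["id"]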

@@ -41,6 +41,7 @@ class LettaCoreToolExecutor(ToolExecutor):
"send_message": self.send_message,
"conversation_search": self.conversation_search,
"archival_memory_search": self.archival_memory_search,
"archival_memory_insert": self.archival_memory_insert,
}
if function_name not in function_map:


@@ -1,4 +1,3 @@
import concurrent
import os
import threading
import time
@@ -8,7 +7,7 @@ import httpx
import openai
import pytest
from dotenv import load_dotenv
from letta_client import Letta
from letta_client import CreateBlock, Letta
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.agents.letta_agent import LettaAgent
@@ -425,23 +424,44 @@ def run_supervisor_worker_group(client: Letta, weather_tool, group_id: str):
return response
import concurrent.futures
def test_anthropic_streaming(client: Letta):
agent_name = "anthropic_tester"
existing_agents = client.agents.list(tags=[agent_name])
for worker in existing_agents:
client.agents.delete(agent_id=worker.id)
def test_multi_agent_broadcast_parallel(client: Letta, weather_tool):
start_time = time.time()
num_groups = 5
llm_config = LLMConfig(
model="claude-3-7-sonnet-20250219",
model_endpoint_type="anthropic",
model_endpoint="https://api.anthropic.com/v1",
context_window=32000,
handle=f"anthropic/claude-3-5-sonnet-20241022",
put_inner_thoughts_in_kwargs=False,
max_tokens=4096,
enable_reasoner=True,
max_reasoning_tokens=1024,
)
with concurrent.futures.ThreadPoolExecutor(max_workers=num_groups) as executor:
futures = []
for i in range(num_groups):
group_id = str(uuid.uuid4())[:8]
futures.append(executor.submit(run_supervisor_worker_group, client, weather_tool, group_id))
agent = client.agents.create(
name=agent_name,
tags=[agent_name],
include_base_tools=True,
embedding="letta/letta-free",
llm_config=llm_config,
memory_blocks=[CreateBlock(label="human", value="")],
# tool_rules=[InitToolRule(tool_name="core_memory_append")]
)
results = [f.result() for f in futures]
response = client.agents.messages.create_stream(
agent_id=agent.id,
messages=[
{
"role": "user",
"content": "Use core memory append to append `banana` to the persona core memory.",
}
],
stream_tokens=True,
)
# Optionally: assert something or log runtimes
print(f"Executed {num_groups} supervisor-worker groups in parallel.")
print(f"Total runtime: {time.time() - start_time:.2f} seconds")
for idx, r in enumerate(results):
assert r is not None, f"Group {idx} returned no response"
print(list(response))