feat: Enable Anthropic streaming on new agent loop (#1550)
This commit is contained in:
@@ -295,6 +295,7 @@ class Agent(BaseAgent):
|
||||
and not self.supports_structured_output
|
||||
and len(self.tool_rules_solver.init_tool_rules) > 0
|
||||
):
|
||||
# TODO: This just seems wrong? What if there are more than 1 init tool rules?
|
||||
force_tool_call = self.tool_rules_solver.init_tool_rules[0].tool_name
|
||||
# Force a tool call if exactly one tool is specified
|
||||
elif step_count is not None and step_count > 0 and len(allowed_tool_names) == 1:
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Optional
|
||||
from typing import Any, AsyncGenerator, Optional, Union
|
||||
|
||||
import openai
|
||||
|
||||
from letta.schemas.letta_message import UserMessage
|
||||
from letta.schemas.enums import MessageStreamStatus
|
||||
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, UserMessage
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
@@ -39,9 +40,11 @@ class BaseAgent(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def step_stream(self, input_message: UserMessage, max_steps: int = 10) -> AsyncGenerator[str, None]:
|
||||
async def step_stream(
|
||||
self, input_message: UserMessage, max_steps: int = 10
|
||||
) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
|
||||
"""
|
||||
Main async execution loop for the agent. Implementations must yield messages as SSE events.
|
||||
Main streaming execution loop for the agent.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
52
letta/agents/helpers.py
Normal file
52
letta/agents/helpers.py
Normal file
@@ -0,0 +1,52 @@
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.utils import create_user_message
|
||||
from letta.services.message_manager import MessageManager
|
||||
|
||||
|
||||
def _create_letta_response(new_in_context_messages: list[Message], use_assistant_message: bool) -> LettaResponse:
    """
    Build a LettaResponse from the messages produced during an agent step.

    Every message is expanded with ``to_letta_message`` and the results are
    flattened, in order, into a single list.

    Args:
        new_in_context_messages: Messages to convert, in the order they should appear.
        use_assistant_message: Forwarded unchanged to ``Message.to_letta_message``.

    Returns:
        A LettaResponse holding the flattened messages and a default-constructed
        LettaUsageStatistics (no usage figures are filled in here).
    """
    letta_messages = [
        letta_message
        for message in new_in_context_messages
        for letta_message in message.to_letta_message(use_assistant_message=use_assistant_message)
    ]
    return LettaResponse(messages=letta_messages, usage=LettaUsageStatistics())
|
||||
|
||||
|
||||
def _prepare_in_context_messages(
    input_message: Dict, agent_state: AgentState, message_manager: MessageManager, actor: User
) -> Tuple[List[Message], List[Message]]:
    """
    Load the agent's existing context and persist the incoming user input as a new message.

    Args:
        input_message (Dict): The new user input to turn into a stored message.
        agent_state (AgentState): Supplies ``message_ids``, ``message_buffer_autoclear`` and the agent ``id``.
        message_manager (MessageManager): Used to look up existing messages and to create the new one.
        actor (User): Passed through to every manager call and to ``create_user_message``.

    Returns:
        Tuple[List[Message], List[Message]]:
            - The current in-context messages: everything referenced by ``agent_state.message_ids``,
              or only the first of those messages when ``message_buffer_autoclear`` is set.
            - The new in-context messages: the result of persisting the message built from ``input_message``.
    """
    autoclear = agent_state.message_buffer_autoclear
    stored_messages = message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=actor)

    # With autoclear on, only the first stored message (usually the system message) is kept as context.
    current_in_context_messages = [stored_messages[0]] if autoclear else stored_messages

    # Turn the input into a user message and persist it.
    user_message = create_user_message(input_message=input_message, agent_id=agent_state.id, actor=actor)
    new_in_context_messages = message_manager.create_many_messages([user_message], actor=actor)

    return current_in_context_messages, new_in_context_messages
|
||||
@@ -1,27 +1,32 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Tuple
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL
|
||||
from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.tool_execution_helper import enable_strict_mode
|
||||
from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
from letta.log import get_logger
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageStreamStatus
|
||||
from letta.schemas.letta_message import AssistantMessage
|
||||
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.message import Message, MessageUpdate
|
||||
from letta.schemas.openai.chat_completion_request import UserMessage
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.openai.chat_completion_response import ToolCall
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.utils import create_tool_call_messages_from_openai_response, create_user_message
|
||||
from letta.server.rest_api.utils import create_letta_messages_from_llm_response
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.agent_manager_helper import compile_system_message
|
||||
@@ -58,76 +63,130 @@ class LettaAgent(BaseAgent):
|
||||
async def step(self, input_message: UserMessage, max_steps: int = 10) -> LettaResponse:
|
||||
input_message = self.pre_process_input_message(input_message)
|
||||
agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
|
||||
# TODO: Extend to beyond just system message
|
||||
system_message = [self.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=self.actor)[0]]
|
||||
persisted_letta_messages = self.message_manager.create_many_messages(
|
||||
[create_user_message(input_message=input_message, agent_id=agent_state.id, actor=self.actor)], actor=self.actor
|
||||
current_in_context_messages, new_in_context_messages = _prepare_in_context_messages(
|
||||
input_message, agent_state, self.message_manager, self.actor
|
||||
)
|
||||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||||
|
||||
# TODO: Note that we do absolutely 0 pulling in of in-context messages here
|
||||
# TODO: This is specific to B, and needs to be changed
|
||||
llm_client = LLMClient.create(
|
||||
llm_config=agent_state.llm_config,
|
||||
put_inner_thoughts_first=True,
|
||||
)
|
||||
for step in range(max_steps):
|
||||
response = await self._get_ai_reply(
|
||||
in_context_messages=system_message + persisted_letta_messages,
|
||||
llm_client=llm_client,
|
||||
in_context_messages=current_in_context_messages + new_in_context_messages,
|
||||
agent_state=agent_state,
|
||||
tool_rules_solver=tool_rules_solver,
|
||||
stream=False,
|
||||
)
|
||||
persisted_messages, should_continue = await self._handle_ai_response(response, agent_state, tool_rules_solver)
|
||||
persisted_letta_messages.extend(persisted_messages)
|
||||
|
||||
tool_call = response.choices[0].message.tool_calls[0]
|
||||
persisted_messages, should_continue = await self._handle_ai_response(tool_call, agent_state, tool_rules_solver)
|
||||
new_in_context_messages.extend(persisted_messages)
|
||||
|
||||
if not should_continue:
|
||||
break
|
||||
|
||||
# Persist messages
|
||||
# Translate to letta response messages
|
||||
response_messages = []
|
||||
for message in persisted_letta_messages:
|
||||
response_messages += message.to_letta_message(use_assistant_message=self.use_assistant_message)
|
||||
# Extend the in context message ids
|
||||
if not agent_state.message_buffer_autoclear:
|
||||
message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
|
||||
self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
|
||||
|
||||
return LettaResponse(
|
||||
messages=response_messages,
|
||||
# TODO: Actually populate this
|
||||
usage=LettaUsageStatistics(),
|
||||
)
|
||||
return _create_letta_response(new_in_context_messages=new_in_context_messages, use_assistant_message=self.use_assistant_message)
|
||||
|
||||
async def step_stream(self, input_message: UserMessage, max_steps: int = 10) -> AsyncGenerator[str, None]:
|
||||
@trace_method
|
||||
async def step_stream(
|
||||
self, input_message: UserMessage, max_steps: int = 10, use_assistant_message: bool = False
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Main streaming loop that yields partial tokens.
|
||||
Whenever we detect a tool call, we yield from _handle_ai_response as well.
|
||||
"""
|
||||
raise NotImplementedError("Not implemented for letta agent")
|
||||
input_message = self.pre_process_input_message(input_message)
|
||||
agent_state = self.agent_manager.get_agent_by_id(self.agent_id, actor=self.actor)
|
||||
current_in_context_messages, new_in_context_messages = _prepare_in_context_messages(
|
||||
input_message, agent_state, self.message_manager, self.actor
|
||||
)
|
||||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||||
llm_client = LLMClient.create(
|
||||
llm_config=agent_state.llm_config,
|
||||
put_inner_thoughts_first=True,
|
||||
)
|
||||
|
||||
for step in range(max_steps):
|
||||
stream = await self._get_ai_reply(
|
||||
llm_client=llm_client,
|
||||
in_context_messages=current_in_context_messages + new_in_context_messages,
|
||||
agent_state=agent_state,
|
||||
tool_rules_solver=tool_rules_solver,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# TODO: THIS IS INCREDIBLY UGLY
|
||||
# TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
|
||||
interface = AnthropicStreamingInterface(
|
||||
use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=llm_client.llm_config.put_inner_thoughts_in_kwargs
|
||||
)
|
||||
async for chunk in interface.process(stream):
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
# Process resulting stream content
|
||||
tool_call = interface.get_tool_call_object()
|
||||
reasoning_content = interface.get_reasoning_content()
|
||||
persisted_messages, should_continue = await self._handle_ai_response(
|
||||
tool_call,
|
||||
agent_state,
|
||||
tool_rules_solver,
|
||||
reasoning_content=reasoning_content,
|
||||
pre_computed_assistant_message_id=interface.letta_assistant_message_id,
|
||||
pre_computed_tool_message_id=interface.letta_tool_message_id,
|
||||
)
|
||||
new_in_context_messages.extend(persisted_messages)
|
||||
|
||||
if not should_continue:
|
||||
break
|
||||
|
||||
# Extend the in context message ids
|
||||
if not agent_state.message_buffer_autoclear:
|
||||
message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
|
||||
self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
|
||||
|
||||
# TODO: Also yield out a letta usage stats SSE
|
||||
|
||||
yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n"
|
||||
|
||||
@trace_method
|
||||
async def _get_ai_reply(
|
||||
self,
|
||||
llm_client: LLMClientBase,
|
||||
in_context_messages: List[Message],
|
||||
agent_state: AgentState,
|
||||
tool_rules_solver: ToolRulesSolver,
|
||||
stream: bool,
|
||||
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
|
||||
in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
|
||||
|
||||
tools = [
|
||||
t
|
||||
for t in agent_state.tools
|
||||
if t.tool_type in {ToolType.CUSTOM}
|
||||
or (t.tool_type == ToolType.LETTA_CORE and t.name == DEFAULT_MESSAGE_TOOL)
|
||||
if t.tool_type in {ToolType.CUSTOM, ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE}
|
||||
or (t.tool_type == ToolType.LETTA_MULTI_AGENT_CORE and t.name == "send_message_to_agents_matching_tags")
|
||||
]
|
||||
|
||||
valid_tool_names = set(tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools])))
|
||||
allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in valid_tool_names]
|
||||
valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools]))
|
||||
# TODO: Copied from legacy agent loop, so please be cautious
|
||||
# Set force tool
|
||||
force_tool_call = None
|
||||
if len(valid_tool_names) == 1:
|
||||
force_tool_call = valid_tool_names[0]
|
||||
|
||||
llm_client = LLMClient.create(
|
||||
llm_config=agent_state.llm_config,
|
||||
put_inner_thoughts_first=True,
|
||||
)
|
||||
allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
|
||||
|
||||
response = await llm_client.send_llm_request_async(
|
||||
messages=in_context_messages,
|
||||
tools=allowed_tools,
|
||||
tool_call=None,
|
||||
stream=False,
|
||||
force_tool_call=force_tool_call,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
return response
|
||||
@@ -135,19 +194,18 @@ class LettaAgent(BaseAgent):
|
||||
@trace_method
|
||||
async def _handle_ai_response(
|
||||
self,
|
||||
chat_completion_response: ChatCompletion,
|
||||
tool_call: ToolCall,
|
||||
agent_state: AgentState,
|
||||
tool_rules_solver: ToolRulesSolver,
|
||||
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
|
||||
pre_computed_assistant_message_id: Optional[str] = None,
|
||||
pre_computed_tool_message_id: Optional[str] = None,
|
||||
) -> Tuple[List[Message], bool]:
|
||||
"""
|
||||
Now that streaming is done, handle the final AI response.
|
||||
This might yield additional SSE tokens if we do stalling.
|
||||
At the end, set self._continue_execution accordingly.
|
||||
"""
|
||||
# TODO: Some key assumptions here.
|
||||
# TODO: Assume every call has a tool call, i.e. tool_choice is REQUIRED
|
||||
tool_call = chat_completion_response.choices[0].message.tool_calls[0]
|
||||
|
||||
tool_call_name = tool_call.function.name
|
||||
tool_call_args_str = tool_call.function.arguments
|
||||
|
||||
@@ -158,6 +216,8 @@ class LettaAgent(BaseAgent):
|
||||
|
||||
# Get request heartbeats and coerce to bool
|
||||
request_heartbeat = tool_args.pop("request_heartbeat", False)
|
||||
# Pre-emptively pop out inner_thoughts
|
||||
tool_args.pop(INNER_THOUGHTS_KWARG, "")
|
||||
|
||||
# So this is necessary, because sometimes non-structured outputs makes mistakes
|
||||
if not isinstance(request_heartbeat, bool):
|
||||
@@ -186,7 +246,7 @@ class LettaAgent(BaseAgent):
|
||||
continue_stepping = True
|
||||
|
||||
# 5. Persist to DB
|
||||
tool_call_messages = create_tool_call_messages_from_openai_response(
|
||||
tool_call_messages = create_letta_messages_from_llm_response(
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
function_name=tool_call_name,
|
||||
@@ -196,6 +256,9 @@ class LettaAgent(BaseAgent):
|
||||
function_response=tool_result,
|
||||
actor=self.actor,
|
||||
add_heartbeat_request_system_message=continue_stepping,
|
||||
reasoning_content=reasoning_content,
|
||||
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
|
||||
pre_computed_tool_message_id=pre_computed_tool_message_id,
|
||||
)
|
||||
persisted_messages = self.message_manager.create_many_messages(tool_call_messages, actor=self.actor)
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ from letta.schemas.user import User
|
||||
from letta.server.rest_api.utils import (
|
||||
convert_letta_messages_to_openai,
|
||||
create_assistant_messages_from_openai_response,
|
||||
create_tool_call_messages_from_openai_response,
|
||||
create_letta_messages_from_llm_response,
|
||||
create_user_message,
|
||||
)
|
||||
from letta.services.agent_manager import AgentManager
|
||||
@@ -207,7 +207,7 @@ class VoiceAgent(BaseAgent):
|
||||
in_memory_message_history.append(heartbeat_user_message.model_dump())
|
||||
|
||||
# 5. Also store in DB
|
||||
tool_call_messages = create_tool_call_messages_from_openai_response(
|
||||
tool_call_messages = create_letta_messages_from_llm_response(
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
function_name=tool_call_name,
|
||||
|
||||
323
letta/interfaces/anthropic_streaming_interface.py
Normal file
323
letta/interfaces/anthropic_streaming_interface.py
Normal file
@@ -0,0 +1,323 @@
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import AsyncGenerator, List, Union
|
||||
|
||||
from anthropic import AsyncStream
|
||||
from anthropic.types.beta import (
|
||||
BetaInputJSONDelta,
|
||||
BetaRawContentBlockDeltaEvent,
|
||||
BetaRawContentBlockStartEvent,
|
||||
BetaRawContentBlockStopEvent,
|
||||
BetaRawMessageDeltaEvent,
|
||||
BetaRawMessageStartEvent,
|
||||
BetaRawMessageStopEvent,
|
||||
BetaRawMessageStreamEvent,
|
||||
BetaRedactedThinkingBlock,
|
||||
BetaSignatureDelta,
|
||||
BetaTextBlock,
|
||||
BetaTextDelta,
|
||||
BetaThinkingBlock,
|
||||
BetaThinkingDelta,
|
||||
BetaToolUseBlock,
|
||||
)
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.letta_message import (
|
||||
AssistantMessage,
|
||||
HiddenReasoningMessage,
|
||||
LettaMessage,
|
||||
ReasoningMessage,
|
||||
ToolCallDelta,
|
||||
ToolCallMessage,
|
||||
)
|
||||
from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
|
||||
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# TODO: These modes aren't used right now - but can be useful if we do multiple sequential tool calling within one Claude message
|
||||
class EventMode(Enum):
    """Kind of Anthropic content block that the streaming interface is currently inside."""

    # Entered when a BetaTextBlock starts.
    TEXT = "TEXT"
    # Entered when a BetaToolUseBlock starts.
    TOOL_USE = "TOOL_USE"
    # Entered when a BetaThinkingBlock starts.
    THINKING = "THINKING"
    # Entered when a BetaRedactedThinkingBlock starts.
    REDACTED_THINKING = "REDACTED_THINKING"
|
||||
|
||||
|
||||
class AnthropicStreamingInterface:
    """
    Encapsulates the logic for streaming responses from Anthropic.

    This class handles parsing of partial tokens, pre-execution messages,
    and detection of tool call events.

    An instance consumes one Anthropic event stream via ``process`` and, while doing so,
    accumulates the tool call (id, name, argument fragments), the reasoning chunks that
    were emitted, and token usage. After the stream is exhausted, ``get_tool_call_object``
    and ``get_reasoning_content`` expose the accumulated results.
    """

    def __init__(self, use_assistant_message: bool = False, put_inner_thoughts_in_kwarg: bool = False):
        """
        Args:
            use_assistant_message: When True, arguments of the DEFAULT_MESSAGE_TOOL call are
                streamed as AssistantMessage chunks instead of ToolCallMessage chunks.
            put_inner_thoughts_in_kwarg: When True, inner thoughts are expected inside the tool
                call arguments under INNER_THOUGHTS_KWARG, and tool call chunks are held back
                until those inner thoughts are judged complete.
        """
        self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
        self.use_assistant_message = use_assistant_message

        # Premake IDs for database writes
        self.letta_assistant_message_id = Message.generate_id()
        self.letta_tool_message_id = Message.generate_id()

        self.anthropic_mode = None
        self.message_id = None
        self.accumulated_inner_thoughts = []
        self.tool_call_id = None
        self.tool_call_name = None
        self.accumulated_tool_call_args = []
        self.previous_parse = {}

        # usage trackers
        self.input_tokens = 0
        self.output_tokens = 0
        # Output tokens of the current message that are already included in self.output_tokens.
        # Needed because message_delta events report a cumulative count, not an increment.
        self._counted_output_tokens = 0

        # reasoning object trackers
        self.reasoning_messages = []

        # Buffer to hold tool call messages until inner thoughts are complete
        self.tool_call_buffer = []
        self.inner_thoughts_complete = False
        self.put_inner_thoughts_in_kwarg = put_inner_thoughts_in_kwarg

    def get_tool_call_object(self) -> ToolCall:
        """Useful for agent loop: return the tool call assembled from the streamed id, name and argument fragments."""
        return ToolCall(
            id=self.tool_call_id, function=FunctionCall(arguments="".join(self.accumulated_tool_call_args), name=self.tool_call_name)
        )

    def _check_inner_thoughts_complete(self, combined_args: str) -> bool:
        """
        Check if inner thoughts are complete in the current tool call arguments.

        When inner thoughts are not placed in kwargs there is nothing to wait for, so this
        is always True. Otherwise the arguments seen so far are parsed optimistically, and
        the inner thoughts count as complete once the parse contains INNER_THOUGHTS_KWARG
        together with at least one other key (i.e. a later argument has started).
        """
        if not self.put_inner_thoughts_in_kwarg:
            # None of the things should have inner thoughts in kwargs
            return True

        parsed = self.optimistic_json_parser.parse(combined_args)
        # TODO: This will break on tools with 0 input
        return len(parsed) > 1 and INNER_THOUGHTS_KWARG in parsed

    async def process(self, stream: AsyncStream[BetaRawMessageStreamEvent]) -> AsyncGenerator[LettaMessage, None]:
        """
        Consume an Anthropic event stream and yield Letta messages as chunks arrive.

        Yields ReasoningMessage / HiddenReasoningMessage chunks for text, thinking, redacted
        thinking and inner thoughts found in tool arguments; ToolCallMessage chunks for tool
        calls; and AssistantMessage chunks for the DEFAULT_MESSAGE_TOOL when
        ``use_assistant_message`` is set. Tool call chunks are buffered until the inner
        thoughts are complete, and flushed at the latest when the tool-use block stops.

        Raises:
            RuntimeError: If a delta arrives whose type does not match the content block
                that is currently open.
        """
        async with stream:
            async for event in stream:
                # TODO: Support BetaThinkingBlock, BetaRedactedThinkingBlock
                if isinstance(event, BetaRawContentBlockStartEvent):
                    content = event.content_block

                    if isinstance(content, BetaTextBlock):
                        self.anthropic_mode = EventMode.TEXT
                        # TODO: Can capture citations, etc.
                    elif isinstance(content, BetaToolUseBlock):
                        self.anthropic_mode = EventMode.TOOL_USE
                        self.tool_call_id = content.id
                        self.tool_call_name = content.name
                        self.inner_thoughts_complete = False

                        # NOTE(review): with use_assistant_message=True the name/id chunk is skipped for
                        # every tool, not only DEFAULT_MESSAGE_TOOL - confirm this is intended.
                        if not self.use_assistant_message:
                            # Buffer the initial tool call message instead of yielding immediately
                            tool_call_msg = ToolCallMessage(
                                id=self.letta_tool_message_id,
                                tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
                                date=datetime.now(timezone.utc).isoformat(),
                            )
                            self.tool_call_buffer.append(tool_call_msg)
                    elif isinstance(content, BetaThinkingBlock):
                        self.anthropic_mode = EventMode.THINKING
                        # TODO: Can capture signature, etc.
                    elif isinstance(content, BetaRedactedThinkingBlock):
                        self.anthropic_mode = EventMode.REDACTED_THINKING

                        hidden_reasoning_message = HiddenReasoningMessage(
                            id=self.letta_assistant_message_id,
                            state="redacted",
                            hidden_reasoning=content.data,
                            date=datetime.now(timezone.utc).isoformat(),
                        )
                        self.reasoning_messages.append(hidden_reasoning_message)
                        yield hidden_reasoning_message

                elif isinstance(event, BetaRawContentBlockDeltaEvent):
                    delta = event.delta

                    if isinstance(delta, BetaTextDelta):
                        # Safety check
                        if self.anthropic_mode != EventMode.TEXT:
                            raise RuntimeError(
                                f"Streaming integrity failed - received BetaTextDelta object while not in TEXT EventMode: {delta}"
                            )

                        # TODO: Strip out </thinking> more robustly, this is pretty hacky lol
                        delta.text = delta.text.replace("</thinking>", "")
                        self.accumulated_inner_thoughts.append(delta.text)

                        # Plain text deltas are surfaced as reasoning chunks
                        reasoning_message = ReasoningMessage(
                            id=self.letta_assistant_message_id,
                            reasoning=self.accumulated_inner_thoughts[-1],
                            date=datetime.now(timezone.utc).isoformat(),
                        )
                        self.reasoning_messages.append(reasoning_message)
                        yield reasoning_message

                    elif isinstance(delta, BetaInputJSONDelta):
                        if self.anthropic_mode != EventMode.TOOL_USE:
                            raise RuntimeError(
                                f"Streaming integrity failed - received BetaInputJSONDelta object while not in TOOL_USE EventMode: {delta}"
                            )

                        self.accumulated_tool_call_args.append(delta.partial_json)
                        combined_args = "".join(self.accumulated_tool_call_args)
                        current_parsed = self.optimistic_json_parser.parse(combined_args)

                        # Start detecting a difference in inner thoughts
                        previous_inner_thoughts = self.previous_parse.get(INNER_THOUGHTS_KWARG, "")
                        current_inner_thoughts = current_parsed.get(INNER_THOUGHTS_KWARG, "")
                        inner_thoughts_diff = current_inner_thoughts[len(previous_inner_thoughts) :]

                        if inner_thoughts_diff:
                            reasoning_message = ReasoningMessage(
                                id=self.letta_assistant_message_id,
                                reasoning=inner_thoughts_diff,
                                date=datetime.now(timezone.utc).isoformat(),
                            )
                            self.reasoning_messages.append(reasoning_message)
                            yield reasoning_message

                        # Check if inner thoughts are complete - if so, flush the buffer
                        if not self.inner_thoughts_complete and self._check_inner_thoughts_complete(combined_args):
                            self.inner_thoughts_complete = True
                            # Flush all buffered tool call messages
                            for buffered_msg in self.tool_call_buffer:
                                yield buffered_msg
                            self.tool_call_buffer = []

                        # Start detecting special case of "send_message"
                        if self.tool_call_name == DEFAULT_MESSAGE_TOOL and self.use_assistant_message:
                            previous_send_message = self.previous_parse.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
                            current_send_message = current_parsed.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
                            send_message_diff = current_send_message[len(previous_send_message) :]

                            # Only stream out if it's not an empty string
                            if send_message_diff:
                                yield AssistantMessage(
                                    id=self.letta_assistant_message_id,
                                    content=[TextContent(text=send_message_diff)],
                                    date=datetime.now(timezone.utc).isoformat(),
                                )
                        else:
                            # Otherwise, it is a normal tool call - buffer or yield based on inner thoughts status
                            tool_call_msg = ToolCallMessage(
                                id=self.letta_tool_message_id,
                                tool_call=ToolCallDelta(arguments=delta.partial_json),
                                date=datetime.now(timezone.utc).isoformat(),
                            )

                            if self.inner_thoughts_complete:
                                yield tool_call_msg
                            else:
                                self.tool_call_buffer.append(tool_call_msg)

                        # Set previous parse
                        self.previous_parse = current_parsed
                    elif isinstance(delta, BetaThinkingDelta):
                        # Safety check
                        if self.anthropic_mode != EventMode.THINKING:
                            raise RuntimeError(
                                f"Streaming integrity failed - received BetaThinkingDelta object while not in THINKING EventMode: {delta}"
                            )

                        reasoning_message = ReasoningMessage(
                            id=self.letta_assistant_message_id,
                            source="reasoner_model",
                            reasoning=delta.thinking,
                            date=datetime.now(timezone.utc).isoformat(),
                        )
                        self.reasoning_messages.append(reasoning_message)
                        yield reasoning_message
                    elif isinstance(delta, BetaSignatureDelta):
                        # Safety check
                        if self.anthropic_mode != EventMode.THINKING:
                            raise RuntimeError(
                                f"Streaming integrity failed - received BetaSignatureDelta object while not in THINKING EventMode: {delta}"
                            )

                        # The signature travels on an otherwise empty reasoning chunk
                        reasoning_message = ReasoningMessage(
                            id=self.letta_assistant_message_id,
                            source="reasoner_model",
                            reasoning="",
                            date=datetime.now(timezone.utc).isoformat(),
                            signature=delta.signature,
                        )
                        self.reasoning_messages.append(reasoning_message)
                        yield reasoning_message
                elif isinstance(event, BetaRawMessageStartEvent):
                    self.message_id = event.message.id
                    self.input_tokens += event.message.usage.input_tokens
                    self.output_tokens += event.message.usage.output_tokens
                    self._counted_output_tokens = event.message.usage.output_tokens
                elif isinstance(event, BetaRawMessageDeltaEvent):
                    # message_delta usage is cumulative for the message, so add only the part
                    # that has not been counted yet (avoids double-counting message_start tokens).
                    self.output_tokens += event.usage.output_tokens - self._counted_output_tokens
                    self._counted_output_tokens = event.usage.output_tokens
                elif isinstance(event, BetaRawMessageStopEvent):
                    # Don't do anything here! We don't want to stop the stream.
                    pass
                elif isinstance(event, BetaRawContentBlockStopEvent):
                    # If we're exiting a tool use block and there are still buffered messages,
                    # we should flush them now
                    if self.anthropic_mode == EventMode.TOOL_USE and self.tool_call_buffer:
                        for buffered_msg in self.tool_call_buffer:
                            yield buffered_msg
                        self.tool_call_buffer = []

                    self.anthropic_mode = None

    def get_reasoning_content(self) -> List[Union[TextContent, ReasoningContent, RedactedReasoningContent]]:
        """
        Merge the reasoning chunks recorded during ``process`` into content objects.

        Consecutive chunks of the same kind (reasoning vs. redacted) are merged into one
        item, preserving stream order. A reasoning group becomes ReasoningContent when any
        chunk came from the "reasoner_model" source, otherwise TextContent; a redacted
        group becomes RedactedReasoningContent.

        Raises:
            ValueError: If a recorded message is neither a ReasoningMessage nor a
                HiddenReasoningMessage.
        """

        def _process_group(
            group: List[Union[ReasoningMessage, HiddenReasoningMessage]], group_type: str
        ) -> Union[TextContent, ReasoningContent, RedactedReasoningContent]:
            if group_type == "reasoning":
                reasoning_text = "".join(chunk.reasoning for chunk in group)
                is_native = any(chunk.source == "reasoner_model" for chunk in group)
                # First signature found in the group, if any
                signature = next((chunk.signature for chunk in group if chunk.signature is not None), None)
                if is_native:
                    return ReasoningContent(is_native=is_native, reasoning=reasoning_text, signature=signature)
                else:
                    return TextContent(text=reasoning_text)
            elif group_type == "redacted":
                redacted_text = "".join(chunk.hidden_reasoning for chunk in group if chunk.hidden_reasoning is not None)
                return RedactedReasoningContent(data=redacted_text)
            else:
                raise ValueError("Unexpected group type")

        merged = []
        current_group = []
        current_group_type = None  # "reasoning" or "redacted"

        for msg in self.reasoning_messages:
            # Determine the type of the current message
            if isinstance(msg, HiddenReasoningMessage):
                msg_type = "redacted"
            elif isinstance(msg, ReasoningMessage):
                msg_type = "reasoning"
            else:
                raise ValueError("Unexpected message type")

            # Initialize group type if not set
            if current_group_type is None:
                current_group_type = msg_type

            # If the type changes, process the current group
            if msg_type != current_group_type:
                merged.append(_process_group(current_group, current_group_type))
                current_group = []
                current_group_type = msg_type

            current_group.append(msg)

        # Process the final group, if any.
        if current_group:
            merged.append(_process_group(current_group, current_group_type))

        return merged
|
||||
@@ -3,7 +3,9 @@ import re
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import anthropic
|
||||
from anthropic import AsyncStream
|
||||
from anthropic.types import Message as AnthropicMessage
|
||||
from anthropic.types.beta import BetaRawMessageStreamEvent
|
||||
|
||||
from letta.errors import (
|
||||
ContextWindowExceededError,
|
||||
@@ -28,6 +30,7 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
from letta.schemas.openai.chat_completion_response import Message as ChoiceMessage
|
||||
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
from letta.tracing import trace_method
|
||||
|
||||
DUMMY_FIRST_USER_MESSAGE = "User initializing bootup sequence."
|
||||
|
||||
@@ -46,18 +49,28 @@ class AnthropicClient(LLMClientBase):
|
||||
response = await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
|
||||
return response.model_dump()
|
||||
|
||||
@trace_method
|
||||
async def stream_async(self, request_data: dict) -> AsyncStream[BetaRawMessageStreamEvent]:
    """
    Issue a streaming request to the Anthropic beta messages endpoint.

    Args:
        request_data: Keyword arguments for ``client.beta.messages.create``.
            The dict is not modified; any ``"stream"`` entry it carries is
            overridden with ``True`` for this call only.

    Returns:
        The awaited result of ``client.beta.messages.create`` with streaming enabled.
    """
    client = self._get_anthropic_client(async_client=True)
    # Build a merged copy rather than writing into the caller's dict, so a payload
    # that is reused for a non-streaming request does not silently start streaming.
    streaming_request = {**request_data, "stream": True}
    return await client.beta.messages.create(**streaming_request, betas=["tools-2024-04-04"])
|
||||
|
||||
@trace_method
|
||||
def _get_anthropic_client(self, async_client: bool = False) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
    """
    Build an Anthropic SDK client, honoring a provider-level override key if one exists.

    Args:
        async_client: When True an ``anthropic.AsyncAnthropic`` is created,
            otherwise a synchronous ``anthropic.Anthropic``.

    Returns:
        A new client instance. If the override key is truthy it is passed as
        ``api_key``; otherwise the client is constructed with no arguments,
        leaving credential resolution to the SDK defaults.
    """
    override_key = ProviderManager().get_anthropic_override_key()

    # Choose the client flavor once, then decide whether to pass an explicit key.
    client_cls = anthropic.AsyncAnthropic if async_client else anthropic.Anthropic
    if override_key:
        return client_cls(api_key=override_key)
    return client_cls()
|
||||
|
||||
@trace_method
|
||||
def build_request_data(
|
||||
self,
|
||||
messages: List[PydanticMessage],
|
||||
tools: List[dict],
|
||||
force_tool_call: Optional[str] = None,
|
||||
) -> dict:
|
||||
# TODO: This needs to get cleaned up. The logic here is pretty confusing.
|
||||
# TODO: I really want to get rid of prefixing, it's a recipe for disaster code maintenance wise
|
||||
prefix_fill = True
|
||||
if not self.use_tool_naming:
|
||||
raise NotImplementedError("Only tool calling supported on Anthropic API requests")
|
||||
@@ -73,11 +86,6 @@ class AnthropicClient(LLMClientBase):
|
||||
|
||||
# Extended Thinking
|
||||
if self.llm_config.enable_reasoner:
|
||||
assert (
|
||||
self.llm_config.max_reasoning_tokens is not None and self.llm_config.max_reasoning_tokens < self.llm_config.max_tokens
|
||||
), "max tokens must be greater than thinking budget"
|
||||
assert not self.llm_config.put_inner_thoughts_in_kwargs, "extended thinking not compatible with put_inner_thoughts_in_kwargs"
|
||||
|
||||
data["thinking"] = {
|
||||
"type": "enabled",
|
||||
"budget_tokens": self.llm_config.max_reasoning_tokens,
|
||||
@@ -89,15 +97,35 @@ class AnthropicClient(LLMClientBase):
|
||||
prefix_fill = False
|
||||
|
||||
# Tools
|
||||
tools_for_request = (
|
||||
[Tool(function=f) for f in tools if f["name"] == force_tool_call]
|
||||
if force_tool_call is not None
|
||||
else [Tool(function=f) for f in tools]
|
||||
)
|
||||
if force_tool_call is not None:
|
||||
self.llm_config.put_inner_thoughts_in_kwargs = True # why do we do this ?
|
||||
# For an overview on tool choice:
|
||||
# https://docs.anthropic.com/en/docs/build-with-claude/tool-use/overview
|
||||
if not tools:
|
||||
# Special case for summarization path
|
||||
tools_for_request = None
|
||||
tool_choice = None
|
||||
elif force_tool_call is not None:
|
||||
tool_choice = {"type": "tool", "name": force_tool_call}
|
||||
tools_for_request = [Tool(function=f) for f in tools if f["name"] == force_tool_call]
|
||||
|
||||
# need to have this setting to be able to put inner thoughts in kwargs
|
||||
if not self.llm_config.put_inner_thoughts_in_kwargs:
|
||||
logger.warning(
|
||||
f"Force setting put_inner_thoughts_in_kwargs to True for Claude because there is a forced tool call: {force_tool_call}"
|
||||
)
|
||||
self.llm_config.put_inner_thoughts_in_kwargs = True
|
||||
else:
|
||||
if self.llm_config.put_inner_thoughts_in_kwargs:
|
||||
# tool_choice_type other than "auto" only plays nice if thinking goes inside the tool calls
|
||||
tool_choice = {"type": "any", "disable_parallel_tool_use": True}
|
||||
else:
|
||||
tool_choice = {"type": "auto", "disable_parallel_tool_use": True}
|
||||
tools_for_request = [Tool(function=f) for f in tools] if tools is not None else None
|
||||
|
||||
# Add tool choice
|
||||
data["tool_choice"] = tool_choice
|
||||
|
||||
# Add inner thoughts kwarg
|
||||
# TODO: Can probably make this more efficient
|
||||
if len(tools_for_request) > 0 and self.llm_config.put_inner_thoughts_in_kwargs:
|
||||
tools_with_inner_thoughts = add_inner_thoughts_to_functions(
|
||||
functions=[t.function.model_dump() for t in tools_for_request],
|
||||
|
||||
@@ -30,10 +30,11 @@ class JobStatus(str, Enum):
|
||||
|
||||
|
||||
class MessageStreamStatus(str, Enum):
    """
    String-valued sentinel(s) used to signal status on a message stream.

    Only ``done`` is currently active.
    """

    # Disabled markers, kept for reference:
    #   done_generation = "[DONE_GEN]"
    #   done_step = "[DONE_STEP]"
    done = "[DONE]"

    def model_dump_json(self) -> str:
        """
        Return the literal ``[DONE]`` marker, regardless of which member is serialized.

        NOTE(review): presumably named after pydantic's ``model_dump_json`` so stream
        consumers can serialize statuses and messages uniformly - confirm against callers.
        """
        return "[DONE]"
|
||||
|
||||
|
||||
class ToolRuleType(str, Enum):
|
||||
"""
|
||||
|
||||
@@ -123,9 +123,9 @@ class ToolCall(BaseModel):
|
||||
|
||||
|
||||
class ToolCallDelta(BaseModel):
|
||||
name: Optional[str]
|
||||
arguments: Optional[str]
|
||||
tool_call_id: Optional[str]
|
||||
name: Optional[str] = None
|
||||
arguments: Optional[str] = None
|
||||
tool_call_id: Optional[str] = None
|
||||
|
||||
def model_dump(self, *args, **kwargs):
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, root_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
class LLMConfig(BaseModel):
|
||||
@@ -70,7 +70,8 @@ class LLMConfig(BaseModel):
|
||||
# FIXME hack to silence pydantic protected namespace warning
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@root_validator(pre=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def set_default_put_inner_thoughts(cls, values):
|
||||
"""
|
||||
Dynamically set the default for put_inner_thoughts_in_kwargs based on the model field,
|
||||
@@ -79,15 +80,24 @@ class LLMConfig(BaseModel):
|
||||
model = values.get("model")
|
||||
|
||||
# Define models where we want put_inner_thoughts_in_kwargs to be False
|
||||
# For now it is gpt-4
|
||||
avoid_put_inner_thoughts_in_kwargs = ["gpt-4"]
|
||||
|
||||
# Only modify the value if it's None or not provided
|
||||
if values.get("put_inner_thoughts_in_kwargs") is None:
|
||||
values["put_inner_thoughts_in_kwargs"] = False if model in avoid_put_inner_thoughts_in_kwargs else True
|
||||
|
||||
return values
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_reasoning_constraints(self) -> "LLMConfig":
    """
    Check that the extended-thinking settings are mutually consistent.

    Nothing is checked unless ``enable_reasoner`` is truthy. When it is:

    * ``max_reasoning_tokens`` must be set;
    * if ``max_tokens`` is set, it must be strictly greater than ``max_reasoning_tokens``;
    * ``put_inner_thoughts_in_kwargs`` must be falsy.

    Returns:
        ``self``, unchanged.

    Raises:
        ValueError: If any of the constraints above is violated.
    """
    # Reasoner disabled: there is nothing to validate.
    if not self.enable_reasoner:
        return self

    thinking_budget = self.max_reasoning_tokens
    if thinking_budget is None:
        raise ValueError("max_reasoning_tokens must be set when enable_reasoner is True")

    output_limit = self.max_tokens
    # An unset max_tokens skips the comparison entirely.
    if output_limit is not None and thinking_budget >= output_limit:
        raise ValueError("max_tokens must be greater than max_reasoning_tokens (thinking budget)")

    if self.put_inner_thoughts_in_kwargs:
        raise ValueError("Extended thinking is not compatible with put_inner_thoughts_in_kwargs")

    return self
|
||||
|
||||
@classmethod
|
||||
def default_config(cls, model_name: str):
|
||||
"""
|
||||
|
||||
@@ -8,6 +8,7 @@ from fastapi.responses import JSONResponse
|
||||
from marshmallow import ValidationError
|
||||
from pydantic import Field
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
@@ -30,7 +31,6 @@ from letta.schemas.user import User
|
||||
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
from letta.settings import settings
|
||||
|
||||
# These can be forward refs, but because FastAPI needs them at runtime they must be imported normally
|
||||
|
||||
@@ -590,8 +590,10 @@ async def send_message(
|
||||
This endpoint accepts a message from a user and processes it through the agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
if settings.use_experimental:
|
||||
logger.warning("USING EXPERIMENTAL!")
|
||||
# TODO: This is redundant, remove soon
|
||||
agent = server.agent_manager.get_agent_by_id(agent_id, actor)
|
||||
|
||||
if agent.llm_config.model_endpoint_type == "anthropic":
|
||||
experimental_agent = LettaAgent(
|
||||
agent_id=agent_id,
|
||||
message_manager=server.message_manager,
|
||||
@@ -644,17 +646,39 @@ async def send_message_streaming(
|
||||
It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
result = await server.send_message_to_agent(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
messages=request.messages,
|
||||
stream_steps=True,
|
||||
stream_tokens=request.stream_tokens,
|
||||
# Support for AssistantMessage
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
assistant_message_tool_name=request.assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
|
||||
)
|
||||
# TODO: This is redundant, remove soon
|
||||
agent = server.agent_manager.get_agent_by_id(agent_id, actor)
|
||||
|
||||
if agent.llm_config.model_endpoint_type == "anthropic":
|
||||
logger.warning("USING EXPERIMENTAL!")
|
||||
experimental_agent = LettaAgent(
|
||||
agent_id=agent_id,
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
messages = request.messages
|
||||
content = messages[0].content[0].text if messages and not isinstance(messages[0].content, str) else messages[0].content
|
||||
result = StreamingResponse(
|
||||
experimental_agent.step_stream(UserMessage(content=content), max_steps=10, use_assistant_message=request.use_assistant_message),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
else:
|
||||
result = await server.send_message_to_agent(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
messages=request.messages,
|
||||
stream_steps=True,
|
||||
stream_tokens=request.stream_tokens,
|
||||
# Support for AssistantMessage
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
assistant_message_tool_name=request.assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -670,31 +694,17 @@ async def process_message_background(
|
||||
) -> None:
|
||||
"""Background task to process the message and update job status."""
|
||||
try:
|
||||
# TODO(matt) we should probably make this stream_steps and log each step as it progresses, so the job update GET can see the total steps so far + partial usage?
|
||||
if settings.use_experimental:
|
||||
logger.warning("USING EXPERIMENTAL!")
|
||||
experimental_agent = LettaAgent(
|
||||
agent_id=agent_id,
|
||||
message_manager=server.message_manager,
|
||||
agent_manager=server.agent_manager,
|
||||
block_manager=server.block_manager,
|
||||
passage_manager=server.passage_manager,
|
||||
actor=actor,
|
||||
)
|
||||
content = messages[0].content[0].text if messages and not isinstance(messages[0].content, str) else messages[0].content
|
||||
result = await experimental_agent.step(UserMessage(content=content), max_steps=10)
|
||||
else:
|
||||
result = await server.send_message_to_agent(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
messages=messages,
|
||||
stream_steps=False, # NOTE(matt)
|
||||
stream_tokens=False,
|
||||
use_assistant_message=use_assistant_message,
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
metadata={"job_id": job_id}, # Pass job_id through metadata
|
||||
)
|
||||
result = await server.send_message_to_agent(
|
||||
agent_id=agent_id,
|
||||
actor=actor,
|
||||
messages=messages,
|
||||
stream_steps=False, # NOTE(matt)
|
||||
stream_tokens=False,
|
||||
use_assistant_message=use_assistant_message,
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
metadata={"job_id": job_id}, # Pass job_id through metadata
|
||||
)
|
||||
|
||||
# Update job status to completed
|
||||
job_update = JobUpdate(
|
||||
|
||||
@@ -18,7 +18,7 @@ from letta.errors import ContextWindowExceededError, RateLimitExceededError
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
@@ -167,7 +167,7 @@ def create_user_message(input_message: dict, agent_id: str, actor: User) -> Mess
|
||||
return user_message
|
||||
|
||||
|
||||
def create_tool_call_messages_from_openai_response(
|
||||
def create_letta_messages_from_llm_response(
|
||||
agent_id: str,
|
||||
model: str,
|
||||
function_name: str,
|
||||
@@ -177,6 +177,9 @@ def create_tool_call_messages_from_openai_response(
|
||||
function_response: Optional[str],
|
||||
actor: User,
|
||||
add_heartbeat_request_system_message: bool = False,
|
||||
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
|
||||
pre_computed_assistant_message_id: Optional[str] = None,
|
||||
pre_computed_tool_message_id: Optional[str] = None,
|
||||
) -> List[Message]:
|
||||
messages = []
|
||||
|
||||
@@ -190,9 +193,11 @@ def create_tool_call_messages_from_openai_response(
|
||||
),
|
||||
type="function",
|
||||
)
|
||||
# TODO: Use ToolCallContent instead of tool_calls
|
||||
# TODO: This helps preserve ordering
|
||||
assistant_message = Message(
|
||||
role=MessageRole.assistant,
|
||||
content=[],
|
||||
content=reasoning_content if reasoning_content else [],
|
||||
organization_id=actor.organization_id,
|
||||
agent_id=agent_id,
|
||||
model=model,
|
||||
@@ -200,8 +205,12 @@ def create_tool_call_messages_from_openai_response(
|
||||
tool_call_id=tool_call_id,
|
||||
created_at=get_utc_time(),
|
||||
)
|
||||
if pre_computed_assistant_message_id:
|
||||
assistant_message.id = pre_computed_assistant_message_id
|
||||
messages.append(assistant_message)
|
||||
|
||||
# TODO: Use ToolReturnContent instead of TextContent
|
||||
# TODO: This helps preserve ordering
|
||||
tool_message = Message(
|
||||
role=MessageRole.tool,
|
||||
content=[TextContent(text=package_function_response(function_call_success, function_response))],
|
||||
@@ -212,6 +221,8 @@ def create_tool_call_messages_from_openai_response(
|
||||
tool_call_id=tool_call_id,
|
||||
created_at=get_utc_time(),
|
||||
)
|
||||
if pre_computed_tool_message_id:
|
||||
tool_message.id = pre_computed_tool_message_id
|
||||
messages.append(tool_message)
|
||||
|
||||
if add_heartbeat_request_system_message:
|
||||
@@ -243,7 +254,7 @@ def create_assistant_messages_from_openai_response(
|
||||
"""
|
||||
tool_call_id = str(uuid.uuid4())
|
||||
|
||||
return create_tool_call_messages_from_openai_response(
|
||||
return create_letta_messages_from_llm_response(
|
||||
agent_id=agent_id,
|
||||
model=model,
|
||||
function_name=DEFAULT_MESSAGE_TOOL,
|
||||
|
||||
@@ -41,6 +41,7 @@ class LettaCoreToolExecutor(ToolExecutor):
|
||||
"send_message": self.send_message,
|
||||
"conversation_search": self.conversation_search,
|
||||
"archival_memory_search": self.archival_memory_search,
|
||||
"archival_memory_insert": self.archival_memory_insert,
|
||||
}
|
||||
|
||||
if function_name not in function_map:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import concurrent
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
@@ -8,7 +7,7 @@ import httpx
|
||||
import openai
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import Letta
|
||||
from letta_client import CreateBlock, Letta
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
@@ -425,23 +424,44 @@ def run_supervisor_worker_group(client: Letta, weather_tool, group_id: str):
|
||||
return response
|
||||
|
||||
|
||||
import concurrent.futures
|
||||
def test_anthropic_streaming(client: Letta):
|
||||
agent_name = "anthropic_tester"
|
||||
|
||||
existing_agents = client.agents.list(tags=[agent_name])
|
||||
for worker in existing_agents:
|
||||
client.agents.delete(agent_id=worker.id)
|
||||
|
||||
def test_multi_agent_broadcast_parallel(client: Letta, weather_tool):
|
||||
start_time = time.time()
|
||||
num_groups = 5
|
||||
llm_config = LLMConfig(
|
||||
model="claude-3-7-sonnet-20250219",
|
||||
model_endpoint_type="anthropic",
|
||||
model_endpoint="https://api.anthropic.com/v1",
|
||||
context_window=32000,
|
||||
handle=f"anthropic/claude-3-5-sonnet-20241022",
|
||||
put_inner_thoughts_in_kwargs=False,
|
||||
max_tokens=4096,
|
||||
enable_reasoner=True,
|
||||
max_reasoning_tokens=1024,
|
||||
)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=num_groups) as executor:
|
||||
futures = []
|
||||
for i in range(num_groups):
|
||||
group_id = str(uuid.uuid4())[:8]
|
||||
futures.append(executor.submit(run_supervisor_worker_group, client, weather_tool, group_id))
|
||||
agent = client.agents.create(
|
||||
name=agent_name,
|
||||
tags=[agent_name],
|
||||
include_base_tools=True,
|
||||
embedding="letta/letta-free",
|
||||
llm_config=llm_config,
|
||||
memory_blocks=[CreateBlock(label="human", value="")],
|
||||
# tool_rules=[InitToolRule(tool_name="core_memory_append")]
|
||||
)
|
||||
|
||||
results = [f.result() for f in futures]
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent.id,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Use core memory append to append `banana` to the persona core memory.",
|
||||
}
|
||||
],
|
||||
stream_tokens=True,
|
||||
)
|
||||
|
||||
# Optionally: assert something or log runtimes
|
||||
print(f"Executed {num_groups} supervisor-worker groups in parallel.")
|
||||
print(f"Total runtime: {time.time() - start_time:.2f} seconds")
|
||||
for idx, r in enumerate(results):
|
||||
assert r is not None, f"Group {idx} returned no response"
|
||||
print(list(response))
|
||||
|
||||
Reference in New Issue
Block a user