fix: user messages on new agent loop are not processed in ADE (includes new json parser) (#1934)
This commit is contained in:
@@ -28,7 +28,7 @@ from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.composio_helpers import get_composio_api_key
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.helpers.message_helper import prepare_input_message_create
|
||||
from letta.helpers.message_helper import convert_message_creates_to_messages
|
||||
from letta.interface import AgentInterface
|
||||
from letta.llm_api.helpers import calculate_summarizer_cutoff, get_token_counts_for_messages, is_context_overflow_error
|
||||
from letta.llm_api.llm_api_tools import create
|
||||
@@ -726,8 +726,7 @@ class Agent(BaseAgent):
|
||||
self.tool_rules_solver.clear_tool_history()
|
||||
|
||||
# Convert MessageCreate objects to Message objects
|
||||
message_objects = [prepare_input_message_create(m, self.agent_state.id, True, True) for m in input_messages]
|
||||
next_input_messages = message_objects
|
||||
next_input_messages = convert_message_creates_to_messages(input_messages, self.agent_state.id)
|
||||
counter = 0
|
||||
total_usage = UsageStatistics()
|
||||
step_count = 0
|
||||
|
||||
@@ -109,7 +109,7 @@ class LettaAgent(BaseAgent):
|
||||
)
|
||||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||||
llm_client = LLMClient.create(
|
||||
llm_config=agent_state.llm_config,
|
||||
provider=agent_state.llm_config.model_endpoint_type,
|
||||
put_inner_thoughts_first=True,
|
||||
)
|
||||
|
||||
@@ -125,7 +125,7 @@ class LettaAgent(BaseAgent):
|
||||
# TODO: THIS IS INCREDIBLY UGLY
|
||||
# TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
|
||||
interface = AnthropicStreamingInterface(
|
||||
use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=llm_client.llm_config.put_inner_thoughts_in_kwargs
|
||||
use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs
|
||||
)
|
||||
async for chunk in interface.process(stream):
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
@@ -275,45 +275,49 @@ class LettaAgent(BaseAgent):
|
||||
return persisted_messages, continue_stepping
|
||||
|
||||
def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]:
|
||||
self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor)
|
||||
try:
|
||||
self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor)
|
||||
|
||||
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
|
||||
curr_system_message = in_context_messages[0]
|
||||
curr_memory_str = agent_state.memory.compile()
|
||||
curr_system_message_text = curr_system_message.content[0].text
|
||||
if curr_memory_str in curr_system_message_text:
|
||||
# NOTE: could this cause issues if a block is removed? (substring match would still work)
|
||||
logger.debug(
|
||||
f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
|
||||
)
|
||||
return in_context_messages
|
||||
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
|
||||
curr_system_message = in_context_messages[0]
|
||||
curr_memory_str = agent_state.memory.compile()
|
||||
curr_system_message_text = curr_system_message.content[0].text
|
||||
if curr_memory_str in curr_system_message_text:
|
||||
# NOTE: could this cause issues if a block is removed? (substring match would still work)
|
||||
logger.debug(
|
||||
f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
|
||||
)
|
||||
return in_context_messages
|
||||
|
||||
memory_edit_timestamp = get_utc_time()
|
||||
memory_edit_timestamp = get_utc_time()
|
||||
|
||||
num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
|
||||
num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)
|
||||
num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
|
||||
num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)
|
||||
|
||||
new_system_message_str = compile_system_message(
|
||||
system_prompt=agent_state.system,
|
||||
in_context_memory=agent_state.memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
previous_message_count=num_messages,
|
||||
archival_memory_size=num_archival_memories,
|
||||
)
|
||||
|
||||
diff = united_diff(curr_system_message_text, new_system_message_str)
|
||||
if len(diff) > 0:
|
||||
logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
|
||||
|
||||
new_system_message = self.message_manager.update_message_by_id(
|
||||
curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
|
||||
new_system_message_str = compile_system_message(
|
||||
system_prompt=agent_state.system,
|
||||
in_context_memory=agent_state.memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
previous_message_count=num_messages,
|
||||
archival_memory_size=num_archival_memories,
|
||||
)
|
||||
|
||||
# Skip pulling down the agent's memory again to save on a db call
|
||||
return [new_system_message] + in_context_messages[1:]
|
||||
diff = united_diff(curr_system_message_text, new_system_message_str)
|
||||
if len(diff) > 0:
|
||||
logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
|
||||
|
||||
else:
|
||||
return in_context_messages
|
||||
new_system_message = self.message_manager.update_message_by_id(
|
||||
curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
|
||||
)
|
||||
|
||||
# Skip pulling down the agent's memory again to save on a db call
|
||||
return [new_system_message] + in_context_messages[1:]
|
||||
|
||||
else:
|
||||
return in_context_messages
|
||||
except:
|
||||
logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})")
|
||||
raise
|
||||
|
||||
@trace_method
|
||||
async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]:
|
||||
|
||||
@@ -39,10 +39,10 @@ def generate_langchain_tool_wrapper(
|
||||
) -> tuple[str, str]:
|
||||
tool_name = tool.__class__.__name__
|
||||
import_statement = f"from langchain_community.tools import {tool_name}"
|
||||
extra_module_imports = generate_import_code(additional_imports_module_attr_map)
|
||||
extra_module_imports = _generate_import_code(additional_imports_module_attr_map)
|
||||
|
||||
# Safety check that user has passed in all required imports:
|
||||
assert_all_classes_are_imported(tool, additional_imports_module_attr_map)
|
||||
_assert_all_classes_are_imported(tool, additional_imports_module_attr_map)
|
||||
|
||||
tool_instantiation = f"tool = {generate_imported_tool_instantiation_call_str(tool)}"
|
||||
run_call = f"return tool._run(**kwargs)"
|
||||
@@ -71,7 +71,7 @@ def _assert_code_gen_compilable(code_str):
|
||||
print(f"Syntax error in code: {e}")
|
||||
|
||||
|
||||
def assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additional_imports_module_attr_map: dict[str, str]) -> None:
|
||||
def _assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additional_imports_module_attr_map: dict[str, str]) -> None:
|
||||
# Safety check that user has passed in all required imports:
|
||||
tool_name = tool.__class__.__name__
|
||||
current_class_imports = {tool_name}
|
||||
@@ -193,7 +193,7 @@ def _is_base_model(obj: Any):
|
||||
return isinstance(obj, BaseModel)
|
||||
|
||||
|
||||
def generate_import_code(module_attr_map: Optional[dict]):
|
||||
def _generate_import_code(module_attr_map: Optional[dict]):
|
||||
if not module_attr_map:
|
||||
return ""
|
||||
|
||||
@@ -295,7 +295,7 @@ async def _send_message_to_agent_no_stream(
|
||||
return LettaResponse(messages=final_messages, usage=usage_stats)
|
||||
|
||||
|
||||
async def async_send_message_with_retries(
|
||||
async def _async_send_message_with_retries(
|
||||
server: "SyncServer",
|
||||
sender_agent: "Agent",
|
||||
target_agent_id: str,
|
||||
|
||||
@@ -4,7 +4,24 @@ from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
|
||||
|
||||
def prepare_input_message_create(
|
||||
def convert_message_creates_to_messages(
|
||||
messages: list[MessageCreate],
|
||||
agent_id: str,
|
||||
wrap_user_message: bool = True,
|
||||
wrap_system_message: bool = True,
|
||||
) -> list[Message]:
|
||||
return [
|
||||
_convert_message_create_to_message(
|
||||
message=message,
|
||||
agent_id=agent_id,
|
||||
wrap_user_message=wrap_user_message,
|
||||
wrap_system_message=wrap_system_message,
|
||||
)
|
||||
for message in messages
|
||||
]
|
||||
|
||||
|
||||
def _convert_message_create_to_message(
|
||||
message: MessageCreate,
|
||||
agent_id: str,
|
||||
wrap_user_message: bool = True,
|
||||
@@ -23,12 +40,12 @@ def prepare_input_message_create(
|
||||
raise ValueError("Message content is empty or invalid")
|
||||
|
||||
# Apply wrapping if needed
|
||||
if message.role == MessageRole.user and wrap_user_message:
|
||||
if message.role not in {MessageRole.user, MessageRole.system}:
|
||||
raise ValueError(f"Invalid message role: {message.role}")
|
||||
elif message.role == MessageRole.user and wrap_user_message:
|
||||
message_content = system.package_user_message(user_message=message_content)
|
||||
elif message.role == MessageRole.system and wrap_system_message:
|
||||
message_content = system.package_system_message(system_message=message_content)
|
||||
elif message.role not in {MessageRole.user, MessageRole.system}:
|
||||
raise ValueError(f"Invalid message role: {message.role}")
|
||||
|
||||
return Message(
|
||||
agent_id=agent_id,
|
||||
|
||||
@@ -35,7 +35,7 @@ from letta.schemas.letta_message import (
|
||||
from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
|
||||
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
|
||||
from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -56,7 +56,7 @@ class AnthropicStreamingInterface:
|
||||
"""
|
||||
|
||||
def __init__(self, use_assistant_message: bool = False, put_inner_thoughts_in_kwarg: bool = False):
|
||||
self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
|
||||
self.json_parser: JSONParser = PydanticJSONParser()
|
||||
self.use_assistant_message = use_assistant_message
|
||||
|
||||
# Premake IDs for database writes
|
||||
@@ -68,7 +68,7 @@ class AnthropicStreamingInterface:
|
||||
self.accumulated_inner_thoughts = []
|
||||
self.tool_call_id = None
|
||||
self.tool_call_name = None
|
||||
self.accumulated_tool_call_args = []
|
||||
self.accumulated_tool_call_args = ""
|
||||
self.previous_parse = {}
|
||||
|
||||
# usage trackers
|
||||
@@ -85,193 +85,200 @@ class AnthropicStreamingInterface:
|
||||
|
||||
def get_tool_call_object(self) -> ToolCall:
|
||||
"""Useful for agent loop"""
|
||||
return ToolCall(
|
||||
id=self.tool_call_id, function=FunctionCall(arguments="".join(self.accumulated_tool_call_args), name=self.tool_call_name)
|
||||
)
|
||||
return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=self.accumulated_tool_call_args, name=self.tool_call_name))
|
||||
|
||||
def _check_inner_thoughts_complete(self, combined_args: str) -> bool:
|
||||
"""
|
||||
Check if inner thoughts are complete in the current tool call arguments
|
||||
by looking for a closing quote after the inner_thoughts field
|
||||
"""
|
||||
if not self.put_inner_thoughts_in_kwarg:
|
||||
# None of the things should have inner thoughts in kwargs
|
||||
return True
|
||||
else:
|
||||
parsed = self.optimistic_json_parser.parse(combined_args)
|
||||
# TODO: This will break on tools with 0 input
|
||||
return len(parsed.keys()) > 1 and INNER_THOUGHTS_KWARG in parsed.keys()
|
||||
try:
|
||||
if not self.put_inner_thoughts_in_kwarg:
|
||||
# None of the things should have inner thoughts in kwargs
|
||||
return True
|
||||
else:
|
||||
parsed = self.json_parser.parse(combined_args)
|
||||
# TODO: This will break on tools with 0 input
|
||||
return len(parsed.keys()) > 1 and INNER_THOUGHTS_KWARG in parsed.keys()
|
||||
except Exception as e:
|
||||
logger.error("Error checking inner thoughts: %s", e)
|
||||
raise
|
||||
|
||||
async def process(self, stream: AsyncStream[BetaRawMessageStreamEvent]) -> AsyncGenerator[LettaMessage, None]:
|
||||
async with stream:
|
||||
async for event in stream:
|
||||
# TODO: Support BetaThinkingBlock, BetaRedactedThinkingBlock
|
||||
if isinstance(event, BetaRawContentBlockStartEvent):
|
||||
content = event.content_block
|
||||
try:
|
||||
async with stream:
|
||||
async for event in stream:
|
||||
# TODO: Support BetaThinkingBlock, BetaRedactedThinkingBlock
|
||||
if isinstance(event, BetaRawContentBlockStartEvent):
|
||||
content = event.content_block
|
||||
|
||||
if isinstance(content, BetaTextBlock):
|
||||
self.anthropic_mode = EventMode.TEXT
|
||||
# TODO: Can capture citations, etc.
|
||||
elif isinstance(content, BetaToolUseBlock):
|
||||
self.anthropic_mode = EventMode.TOOL_USE
|
||||
self.tool_call_id = content.id
|
||||
self.tool_call_name = content.name
|
||||
self.inner_thoughts_complete = False
|
||||
if isinstance(content, BetaTextBlock):
|
||||
self.anthropic_mode = EventMode.TEXT
|
||||
# TODO: Can capture citations, etc.
|
||||
elif isinstance(content, BetaToolUseBlock):
|
||||
self.anthropic_mode = EventMode.TOOL_USE
|
||||
self.tool_call_id = content.id
|
||||
self.tool_call_name = content.name
|
||||
self.inner_thoughts_complete = False
|
||||
|
||||
if not self.use_assistant_message:
|
||||
# Buffer the initial tool call message instead of yielding immediately
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_tool_message_id,
|
||||
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
|
||||
if not self.use_assistant_message:
|
||||
# Buffer the initial tool call message instead of yielding immediately
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_tool_message_id,
|
||||
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
self.tool_call_buffer.append(tool_call_msg)
|
||||
elif isinstance(content, BetaThinkingBlock):
|
||||
self.anthropic_mode = EventMode.THINKING
|
||||
# TODO: Can capture signature, etc.
|
||||
elif isinstance(content, BetaRedactedThinkingBlock):
|
||||
self.anthropic_mode = EventMode.REDACTED_THINKING
|
||||
|
||||
hidden_reasoning_message = HiddenReasoningMessage(
|
||||
id=self.letta_assistant_message_id,
|
||||
state="redacted",
|
||||
hidden_reasoning=content.data,
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
self.tool_call_buffer.append(tool_call_msg)
|
||||
elif isinstance(content, BetaThinkingBlock):
|
||||
self.anthropic_mode = EventMode.THINKING
|
||||
# TODO: Can capture signature, etc.
|
||||
elif isinstance(content, BetaRedactedThinkingBlock):
|
||||
self.anthropic_mode = EventMode.REDACTED_THINKING
|
||||
self.reasoning_messages.append(hidden_reasoning_message)
|
||||
yield hidden_reasoning_message
|
||||
|
||||
hidden_reasoning_message = HiddenReasoningMessage(
|
||||
id=self.letta_assistant_message_id,
|
||||
state="redacted",
|
||||
hidden_reasoning=content.data,
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
self.reasoning_messages.append(hidden_reasoning_message)
|
||||
yield hidden_reasoning_message
|
||||
elif isinstance(event, BetaRawContentBlockDeltaEvent):
|
||||
delta = event.delta
|
||||
|
||||
elif isinstance(event, BetaRawContentBlockDeltaEvent):
|
||||
delta = event.delta
|
||||
if isinstance(delta, BetaTextDelta):
|
||||
# Safety check
|
||||
if not self.anthropic_mode == EventMode.TEXT:
|
||||
raise RuntimeError(
|
||||
f"Streaming integrity failed - received BetaTextDelta object while not in TEXT EventMode: {delta}"
|
||||
)
|
||||
|
||||
if isinstance(delta, BetaTextDelta):
|
||||
# Safety check
|
||||
if not self.anthropic_mode == EventMode.TEXT:
|
||||
raise RuntimeError(
|
||||
f"Streaming integrity failed - received BetaTextDelta object while not in TEXT EventMode: {delta}"
|
||||
)
|
||||
# TODO: Strip out </thinking> more robustly, this is pretty hacky lol
|
||||
delta.text = delta.text.replace("</thinking>", "")
|
||||
self.accumulated_inner_thoughts.append(delta.text)
|
||||
|
||||
# TODO: Strip out </thinking> more robustly, this is pretty hacky lol
|
||||
delta.text = delta.text.replace("</thinking>", "")
|
||||
self.accumulated_inner_thoughts.append(delta.text)
|
||||
|
||||
reasoning_message = ReasoningMessage(
|
||||
id=self.letta_assistant_message_id,
|
||||
reasoning=self.accumulated_inner_thoughts[-1],
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
yield reasoning_message
|
||||
|
||||
elif isinstance(delta, BetaInputJSONDelta):
|
||||
if not self.anthropic_mode == EventMode.TOOL_USE:
|
||||
raise RuntimeError(
|
||||
f"Streaming integrity failed - received BetaInputJSONDelta object while not in TOOL_USE EventMode: {delta}"
|
||||
)
|
||||
|
||||
self.accumulated_tool_call_args.append(delta.partial_json)
|
||||
combined_args = "".join(self.accumulated_tool_call_args)
|
||||
current_parsed = self.optimistic_json_parser.parse(combined_args)
|
||||
|
||||
# Start detecting a difference in inner thoughts
|
||||
previous_inner_thoughts = self.previous_parse.get(INNER_THOUGHTS_KWARG, "")
|
||||
current_inner_thoughts = current_parsed.get(INNER_THOUGHTS_KWARG, "")
|
||||
inner_thoughts_diff = current_inner_thoughts[len(previous_inner_thoughts) :]
|
||||
|
||||
if inner_thoughts_diff:
|
||||
reasoning_message = ReasoningMessage(
|
||||
id=self.letta_assistant_message_id,
|
||||
reasoning=inner_thoughts_diff,
|
||||
reasoning=self.accumulated_inner_thoughts[-1],
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
yield reasoning_message
|
||||
|
||||
# Check if inner thoughts are complete - if so, flush the buffer
|
||||
if not self.inner_thoughts_complete and self._check_inner_thoughts_complete(combined_args):
|
||||
self.inner_thoughts_complete = True
|
||||
# Flush all buffered tool call messages
|
||||
elif isinstance(delta, BetaInputJSONDelta):
|
||||
if not self.anthropic_mode == EventMode.TOOL_USE:
|
||||
raise RuntimeError(
|
||||
f"Streaming integrity failed - received BetaInputJSONDelta object while not in TOOL_USE EventMode: {delta}"
|
||||
)
|
||||
|
||||
self.accumulated_tool_call_args += delta.partial_json
|
||||
current_parsed = self.json_parser.parse(self.accumulated_tool_call_args)
|
||||
|
||||
# Start detecting a difference in inner thoughts
|
||||
previous_inner_thoughts = self.previous_parse.get(INNER_THOUGHTS_KWARG, "")
|
||||
current_inner_thoughts = current_parsed.get(INNER_THOUGHTS_KWARG, "")
|
||||
inner_thoughts_diff = current_inner_thoughts[len(previous_inner_thoughts) :]
|
||||
|
||||
if inner_thoughts_diff:
|
||||
reasoning_message = ReasoningMessage(
|
||||
id=self.letta_assistant_message_id,
|
||||
reasoning=inner_thoughts_diff,
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
yield reasoning_message
|
||||
|
||||
# Check if inner thoughts are complete - if so, flush the buffer
|
||||
if not self.inner_thoughts_complete and self._check_inner_thoughts_complete(self.accumulated_tool_call_args):
|
||||
self.inner_thoughts_complete = True
|
||||
# Flush all buffered tool call messages
|
||||
for buffered_msg in self.tool_call_buffer:
|
||||
yield buffered_msg
|
||||
self.tool_call_buffer = []
|
||||
|
||||
# Start detecting special case of "send_message"
|
||||
if self.tool_call_name == DEFAULT_MESSAGE_TOOL and self.use_assistant_message:
|
||||
previous_send_message = self.previous_parse.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
|
||||
current_send_message = current_parsed.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
|
||||
send_message_diff = current_send_message[len(previous_send_message) :]
|
||||
|
||||
# Only stream out if it's not an empty string
|
||||
if send_message_diff:
|
||||
yield AssistantMessage(
|
||||
id=self.letta_assistant_message_id,
|
||||
content=[TextContent(text=send_message_diff)],
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
else:
|
||||
# Otherwise, it is a normal tool call - buffer or yield based on inner thoughts status
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_tool_message_id,
|
||||
tool_call=ToolCallDelta(arguments=delta.partial_json),
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
|
||||
if self.inner_thoughts_complete:
|
||||
yield tool_call_msg
|
||||
else:
|
||||
self.tool_call_buffer.append(tool_call_msg)
|
||||
|
||||
# Set previous parse
|
||||
self.previous_parse = current_parsed
|
||||
elif isinstance(delta, BetaThinkingDelta):
|
||||
# Safety check
|
||||
if not self.anthropic_mode == EventMode.THINKING:
|
||||
raise RuntimeError(
|
||||
f"Streaming integrity failed - received BetaThinkingBlock object while not in THINKING EventMode: {delta}"
|
||||
)
|
||||
|
||||
reasoning_message = ReasoningMessage(
|
||||
id=self.letta_assistant_message_id,
|
||||
source="reasoner_model",
|
||||
reasoning=delta.thinking,
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
yield reasoning_message
|
||||
elif isinstance(delta, BetaSignatureDelta):
|
||||
# Safety check
|
||||
if not self.anthropic_mode == EventMode.THINKING:
|
||||
raise RuntimeError(
|
||||
f"Streaming integrity failed - received BetaSignatureDelta object while not in THINKING EventMode: {delta}"
|
||||
)
|
||||
|
||||
reasoning_message = ReasoningMessage(
|
||||
id=self.letta_assistant_message_id,
|
||||
source="reasoner_model",
|
||||
reasoning="",
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
signature=delta.signature,
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
yield reasoning_message
|
||||
elif isinstance(event, BetaRawMessageStartEvent):
|
||||
self.message_id = event.message.id
|
||||
self.input_tokens += event.message.usage.input_tokens
|
||||
self.output_tokens += event.message.usage.output_tokens
|
||||
elif isinstance(event, BetaRawMessageDeltaEvent):
|
||||
self.output_tokens += event.usage.output_tokens
|
||||
elif isinstance(event, BetaRawMessageStopEvent):
|
||||
# Don't do anything here! We don't want to stop the stream.
|
||||
pass
|
||||
elif isinstance(event, BetaRawContentBlockStopEvent):
|
||||
# If we're exiting a tool use block and there are still buffered messages,
|
||||
# we should flush them now
|
||||
if self.anthropic_mode == EventMode.TOOL_USE and self.tool_call_buffer:
|
||||
for buffered_msg in self.tool_call_buffer:
|
||||
yield buffered_msg
|
||||
self.tool_call_buffer = []
|
||||
|
||||
# Start detecting special case of "send_message"
|
||||
if self.tool_call_name == DEFAULT_MESSAGE_TOOL and self.use_assistant_message:
|
||||
previous_send_message = self.previous_parse.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
|
||||
current_send_message = current_parsed.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
|
||||
send_message_diff = current_send_message[len(previous_send_message) :]
|
||||
|
||||
# Only stream out if it's not an empty string
|
||||
if send_message_diff:
|
||||
yield AssistantMessage(
|
||||
id=self.letta_assistant_message_id,
|
||||
content=[TextContent(text=send_message_diff)],
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
else:
|
||||
# Otherwise, it is a normal tool call - buffer or yield based on inner thoughts status
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_tool_message_id,
|
||||
tool_call=ToolCallDelta(arguments=delta.partial_json),
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
|
||||
if self.inner_thoughts_complete:
|
||||
yield tool_call_msg
|
||||
else:
|
||||
self.tool_call_buffer.append(tool_call_msg)
|
||||
|
||||
# Set previous parse
|
||||
self.previous_parse = current_parsed
|
||||
elif isinstance(delta, BetaThinkingDelta):
|
||||
# Safety check
|
||||
if not self.anthropic_mode == EventMode.THINKING:
|
||||
raise RuntimeError(
|
||||
f"Streaming integrity failed - received BetaThinkingBlock object while not in THINKING EventMode: {delta}"
|
||||
)
|
||||
|
||||
reasoning_message = ReasoningMessage(
|
||||
id=self.letta_assistant_message_id,
|
||||
source="reasoner_model",
|
||||
reasoning=delta.thinking,
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
yield reasoning_message
|
||||
elif isinstance(delta, BetaSignatureDelta):
|
||||
# Safety check
|
||||
if not self.anthropic_mode == EventMode.THINKING:
|
||||
raise RuntimeError(
|
||||
f"Streaming integrity failed - received BetaSignatureDelta object while not in THINKING EventMode: {delta}"
|
||||
)
|
||||
|
||||
reasoning_message = ReasoningMessage(
|
||||
id=self.letta_assistant_message_id,
|
||||
source="reasoner_model",
|
||||
reasoning="",
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
signature=delta.signature,
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
yield reasoning_message
|
||||
elif isinstance(event, BetaRawMessageStartEvent):
|
||||
self.message_id = event.message.id
|
||||
self.input_tokens += event.message.usage.input_tokens
|
||||
self.output_tokens += event.message.usage.output_tokens
|
||||
elif isinstance(event, BetaRawMessageDeltaEvent):
|
||||
self.output_tokens += event.usage.output_tokens
|
||||
elif isinstance(event, BetaRawMessageStopEvent):
|
||||
# Don't do anything here! We don't want to stop the stream.
|
||||
pass
|
||||
elif isinstance(event, BetaRawContentBlockStopEvent):
|
||||
# If we're exiting a tool use block and there are still buffered messages,
|
||||
# we should flush them now
|
||||
if self.anthropic_mode == EventMode.TOOL_USE and self.tool_call_buffer:
|
||||
for buffered_msg in self.tool_call_buffer:
|
||||
yield buffered_msg
|
||||
self.tool_call_buffer = []
|
||||
|
||||
self.anthropic_mode = None
|
||||
self.anthropic_mode = None
|
||||
except Exception as e:
|
||||
logger.error("Error processing stream: %s", e)
|
||||
raise
|
||||
finally:
|
||||
logger.info("AnthropicStreamingInterface: Stream processing complete.")
|
||||
|
||||
def get_reasoning_content(self) -> List[Union[TextContent, ReasoningContent, RedactedReasoningContent]]:
|
||||
def _process_group(
|
||||
|
||||
@@ -5,7 +5,7 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice,
|
||||
|
||||
from letta.constants import PRE_EXECUTION_MESSAGE_ARG
|
||||
from letta.interfaces.utils import _format_sse_chunk
|
||||
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
|
||||
from letta.server.rest_api.json_parser import OptimisticJSONParser
|
||||
|
||||
|
||||
class OpenAIChatCompletionsStreamingInterface:
|
||||
|
||||
@@ -12,7 +12,7 @@ from letta.schemas.enums import MessageStreamStatus
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
|
||||
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
|
||||
from letta.server.rest_api.json_parser import OptimisticJSONParser
|
||||
from letta.streaming_interface import AgentChunkStreamingInterface
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -28,7 +28,7 @@ from letta.schemas.letta_message import (
|
||||
from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
|
||||
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
|
||||
from letta.server.rest_api.json_parser import OptimisticJSONParser
|
||||
from letta.streaming_interface import AgentChunkStreamingInterface
|
||||
from letta.streaming_utils import FunctionArgumentsStreamHandler, JSONInnerThoughtsExtractor
|
||||
from letta.utils import parse_json
|
||||
@@ -291,7 +291,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler(json_key=assistant_message_tool_kwarg)
|
||||
|
||||
# @matt's changes here, adopting new optimistic json parser
|
||||
self.current_function_arguments = []
|
||||
self.current_function_arguments = ""
|
||||
self.optimistic_json_parser = OptimisticJSONParser()
|
||||
self.current_json_parse_result = {}
|
||||
|
||||
@@ -387,7 +387,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
def stream_start(self):
|
||||
"""Initialize streaming by activating the generator and clearing any old chunks."""
|
||||
self.streaming_chat_completion_mode_function_name = None
|
||||
self.current_function_arguments = []
|
||||
self.current_function_arguments = ""
|
||||
self.current_json_parse_result = {}
|
||||
|
||||
if not self._active:
|
||||
@@ -398,7 +398,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
def stream_end(self):
|
||||
"""Clean up the stream by deactivating and clearing chunks."""
|
||||
self.streaming_chat_completion_mode_function_name = None
|
||||
self.current_function_arguments = []
|
||||
self.current_function_arguments = ""
|
||||
self.current_json_parse_result = {}
|
||||
|
||||
# if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
|
||||
@@ -609,14 +609,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# early exit to turn into content mode
|
||||
return None
|
||||
if tool_call.function.arguments:
|
||||
self.current_function_arguments.append(tool_call.function.arguments)
|
||||
self.current_function_arguments += tool_call.function.arguments
|
||||
|
||||
# if we're in the middle of parsing a send_message, we'll keep processing the JSON chunks
|
||||
if tool_call.function.arguments and self.streaming_chat_completion_mode_function_name == self.assistant_message_tool_name:
|
||||
# Strip out any extras tokens
|
||||
# In the case that we just have the prefix of something, no message yet, then we should early exit to move to the next chunk
|
||||
combined_args = "".join(self.current_function_arguments)
|
||||
parsed_args = self.optimistic_json_parser.parse(combined_args)
|
||||
parsed_args = self.optimistic_json_parser.parse(self.current_function_arguments)
|
||||
|
||||
if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get(
|
||||
self.assistant_message_tool_kwarg
|
||||
@@ -686,7 +685,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# updates_inner_thoughts = ""
|
||||
# else: # OpenAI
|
||||
# updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
|
||||
self.current_function_arguments.append(tool_call.function.arguments)
|
||||
self.current_function_arguments += tool_call.function.arguments
|
||||
updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
|
||||
|
||||
# If we have inner thoughts, we should output them as a chunk
|
||||
@@ -805,8 +804,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# TODO: THIS IS HORRIBLE
|
||||
# TODO: WE USE THE OLD JSON PARSER EARLIER (WHICH DOES NOTHING) AND NOW THE NEW JSON PARSER
|
||||
# TODO: THIS IS TOTALLY WRONG AND BAD, BUT SAVING FOR A LARGER REWRITE IN THE NEAR FUTURE
|
||||
combined_args = "".join(self.current_function_arguments)
|
||||
parsed_args = self.optimistic_json_parser.parse(combined_args)
|
||||
parsed_args = self.optimistic_json_parser.parse(self.current_function_arguments)
|
||||
|
||||
if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get(
|
||||
self.assistant_message_tool_kwarg
|
||||
|
||||
@@ -1,7 +1,43 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from pydantic_core import from_json
|
||||
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class OptimisticJSONParser:
|
||||
class JSONParser(ABC):
|
||||
@abstractmethod
|
||||
def parse(self, input_str: str) -> Any:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class PydanticJSONParser(JSONParser):
|
||||
"""
|
||||
https://docs.pydantic.dev/latest/concepts/json/#json-parsing
|
||||
If `strict` is True, we will not allow for partial parsing of JSON.
|
||||
|
||||
Compared with `OptimisticJSONParser`, this parser is more strict.
|
||||
Note: This will not partially parse strings which may be decrease parsing speed for message strings
|
||||
"""
|
||||
|
||||
def __init__(self, strict=False):
|
||||
self.strict = strict
|
||||
|
||||
def parse(self, input_str: str) -> Any:
|
||||
if not input_str:
|
||||
return {}
|
||||
try:
|
||||
return from_json(input_str, allow_partial="trailing-strings" if not self.strict else False)
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to parse JSON: {e}")
|
||||
raise
|
||||
|
||||
|
||||
class OptimisticJSONParser(JSONParser):
|
||||
"""
|
||||
A JSON parser that attempts to parse a given string using `json.loads`,
|
||||
and if that fails, it parses as much valid JSON as possible while
|
||||
@@ -13,25 +49,25 @@ class OptimisticJSONParser:
|
||||
def __init__(self, strict=False):
|
||||
self.strict = strict
|
||||
self.parsers = {
|
||||
" ": self.parse_space,
|
||||
"\r": self.parse_space,
|
||||
"\n": self.parse_space,
|
||||
"\t": self.parse_space,
|
||||
"[": self.parse_array,
|
||||
"{": self.parse_object,
|
||||
'"': self.parse_string,
|
||||
"t": self.parse_true,
|
||||
"f": self.parse_false,
|
||||
"n": self.parse_null,
|
||||
" ": self._parse_space,
|
||||
"\r": self._parse_space,
|
||||
"\n": self._parse_space,
|
||||
"\t": self._parse_space,
|
||||
"[": self._parse_array,
|
||||
"{": self._parse_object,
|
||||
'"': self._parse_string,
|
||||
"t": self._parse_true,
|
||||
"f": self._parse_false,
|
||||
"n": self._parse_null,
|
||||
}
|
||||
# Register number parser for digits and signs
|
||||
for char in "0123456789.-":
|
||||
self.parsers[char] = self.parse_number
|
||||
|
||||
self.last_parse_reminding = None
|
||||
self.on_extra_token = self.default_on_extra_token
|
||||
self.on_extra_token = self._default_on_extra_token
|
||||
|
||||
def default_on_extra_token(self, text, data, reminding):
|
||||
def _default_on_extra_token(self, text, data, reminding):
|
||||
print(f"Parsed JSON with extra tokens: {data}, remaining: {reminding}")
|
||||
|
||||
def parse(self, input_str):
|
||||
@@ -45,7 +81,7 @@ class OptimisticJSONParser:
|
||||
try:
|
||||
return json.loads(input_str)
|
||||
except json.JSONDecodeError as decode_error:
|
||||
data, reminding = self.parse_any(input_str, decode_error)
|
||||
data, reminding = self._parse_any(input_str, decode_error)
|
||||
self.last_parse_reminding = reminding
|
||||
if self.on_extra_token and reminding:
|
||||
self.on_extra_token(input_str, data, reminding)
|
||||
@@ -53,7 +89,7 @@ class OptimisticJSONParser:
|
||||
else:
|
||||
return json.loads("{}")
|
||||
|
||||
def parse_any(self, input_str, decode_error):
|
||||
def _parse_any(self, input_str, decode_error):
|
||||
"""Determine which parser to use based on the first character."""
|
||||
if not input_str:
|
||||
raise decode_error
|
||||
@@ -62,11 +98,11 @@ class OptimisticJSONParser:
|
||||
raise decode_error
|
||||
return parser(input_str, decode_error)
|
||||
|
||||
def parse_space(self, input_str, decode_error):
|
||||
def _parse_space(self, input_str, decode_error):
|
||||
"""Strip leading whitespace and parse again."""
|
||||
return self.parse_any(input_str.strip(), decode_error)
|
||||
return self._parse_any(input_str.strip(), decode_error)
|
||||
|
||||
def parse_array(self, input_str, decode_error):
|
||||
def _parse_array(self, input_str, decode_error):
|
||||
"""Parse a JSON array, returning the list and remaining string."""
|
||||
# Skip the '['
|
||||
input_str = input_str[1:]
|
||||
@@ -77,7 +113,7 @@ class OptimisticJSONParser:
|
||||
# Skip the ']'
|
||||
input_str = input_str[1:]
|
||||
break
|
||||
value, input_str = self.parse_any(input_str, decode_error)
|
||||
value, input_str = self._parse_any(input_str, decode_error)
|
||||
array_values.append(value)
|
||||
input_str = input_str.strip()
|
||||
if input_str.startswith(","):
|
||||
@@ -85,7 +121,7 @@ class OptimisticJSONParser:
|
||||
input_str = input_str[1:].strip()
|
||||
return array_values, input_str
|
||||
|
||||
def parse_object(self, input_str, decode_error):
|
||||
def _parse_object(self, input_str, decode_error):
|
||||
"""Parse a JSON object, returning the dict and remaining string."""
|
||||
# Skip the '{'
|
||||
input_str = input_str[1:]
|
||||
@@ -96,7 +132,7 @@ class OptimisticJSONParser:
|
||||
# Skip the '}'
|
||||
input_str = input_str[1:]
|
||||
break
|
||||
key, input_str = self.parse_any(input_str, decode_error)
|
||||
key, input_str = self._parse_any(input_str, decode_error)
|
||||
input_str = input_str.strip()
|
||||
|
||||
if not input_str or input_str[0] == "}":
|
||||
@@ -113,7 +149,7 @@ class OptimisticJSONParser:
|
||||
input_str = input_str[1:]
|
||||
break
|
||||
|
||||
value, input_str = self.parse_any(input_str, decode_error)
|
||||
value, input_str = self._parse_any(input_str, decode_error)
|
||||
obj[key] = value
|
||||
input_str = input_str.strip()
|
||||
if input_str.startswith(","):
|
||||
@@ -121,7 +157,7 @@ class OptimisticJSONParser:
|
||||
input_str = input_str[1:].strip()
|
||||
return obj, input_str
|
||||
|
||||
def parse_string(self, input_str, decode_error):
|
||||
def _parse_string(self, input_str, decode_error):
|
||||
"""Parse a JSON string, respecting escaped quotes if present."""
|
||||
end = input_str.find('"', 1)
|
||||
while end != -1 and input_str[end - 1] == "\\":
|
||||
@@ -166,19 +202,19 @@ class OptimisticJSONParser:
|
||||
|
||||
return num, remainder
|
||||
|
||||
def parse_true(self, input_str, decode_error):
|
||||
def _parse_true(self, input_str, decode_error):
|
||||
"""Parse a 'true' value."""
|
||||
if input_str.startswith(("t", "T")):
|
||||
return True, input_str[4:]
|
||||
raise decode_error
|
||||
|
||||
def parse_false(self, input_str, decode_error):
|
||||
def _parse_false(self, input_str, decode_error):
|
||||
"""Parse a 'false' value."""
|
||||
if input_str.startswith(("f", "F")):
|
||||
return False, input_str[5:]
|
||||
raise decode_error
|
||||
|
||||
def parse_null(self, input_str, decode_error):
|
||||
def _parse_null(self, input_str, decode_error):
|
||||
"""Parse a 'null' value."""
|
||||
if input_str.startswith("n"):
|
||||
return None, input_str[4:]
|
||||
@@ -680,7 +680,7 @@ async def send_message_streaming(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
request: LettaStreamingRequest = Body(...),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
) -> StreamingResponse | LettaResponse:
|
||||
"""
|
||||
Process a user message and return the agent's response.
|
||||
This endpoint accepts a message from a user and processes it through the agent.
|
||||
|
||||
@@ -16,6 +16,7 @@ from pydantic import BaseModel
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE
|
||||
from letta.errors import ContextWindowExceededError, RateLimitExceededError
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.message_helper import convert_message_creates_to_messages
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
|
||||
@@ -143,27 +144,15 @@ def log_error_to_sentry(e):
|
||||
def create_input_messages(input_messages: List[MessageCreate], agent_id: str, actor: User) -> List[Message]:
|
||||
"""
|
||||
Converts a user input message into the internal structured format.
|
||||
"""
|
||||
new_messages = []
|
||||
for input_message in input_messages:
|
||||
# Construct the Message object
|
||||
new_message = Message(
|
||||
id=f"message-{uuid.uuid4()}",
|
||||
role=input_message.role,
|
||||
content=input_message.content,
|
||||
name=input_message.name,
|
||||
otid=input_message.otid,
|
||||
sender_id=input_message.sender_id,
|
||||
organization_id=actor.organization_id,
|
||||
agent_id=agent_id,
|
||||
model=None,
|
||||
tool_calls=None,
|
||||
tool_call_id=None,
|
||||
created_at=get_utc_time(),
|
||||
)
|
||||
new_messages.append(new_message)
|
||||
|
||||
return new_messages
|
||||
TODO (cliandy): this effectively duplicates the functionality of `convert_message_creates_to_messages`,
|
||||
we should unify this when it's clear what message attributes we need.
|
||||
"""
|
||||
|
||||
messages = convert_message_creates_to_messages(input_messages, agent_id, wrap_user_message=False, wrap_system_message=False)
|
||||
for message in messages:
|
||||
message.organization_id = actor.organization_id
|
||||
return messages
|
||||
|
||||
|
||||
def create_letta_messages_from_llm_response(
|
||||
|
||||
@@ -3,7 +3,7 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
|
||||
from letta.server.rest_api.json_parser import OptimisticJSONParser
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
Reference in New Issue
Block a user