fix: user messages on new agent loop are not processed in ADE (includes new json parser) (#1934)

This commit is contained in:
Andy Li
2025-04-30 18:07:42 -07:00
committed by GitHub
parent 57218d2b8f
commit 3d94adbac3
12 changed files with 319 additions and 269 deletions

View File

@@ -28,7 +28,7 @@ from letta.helpers import ToolRulesSolver
from letta.helpers.composio_helpers import get_composio_api_key
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.json_helpers import json_dumps, json_loads
from letta.helpers.message_helper import prepare_input_message_create
from letta.helpers.message_helper import convert_message_creates_to_messages
from letta.interface import AgentInterface
from letta.llm_api.helpers import calculate_summarizer_cutoff, get_token_counts_for_messages, is_context_overflow_error
from letta.llm_api.llm_api_tools import create
@@ -726,8 +726,7 @@ class Agent(BaseAgent):
self.tool_rules_solver.clear_tool_history()
# Convert MessageCreate objects to Message objects
message_objects = [prepare_input_message_create(m, self.agent_state.id, True, True) for m in input_messages]
next_input_messages = message_objects
next_input_messages = convert_message_creates_to_messages(input_messages, self.agent_state.id)
counter = 0
total_usage = UsageStatistics()
step_count = 0

View File

@@ -109,7 +109,7 @@ class LettaAgent(BaseAgent):
)
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
llm_client = LLMClient.create(
llm_config=agent_state.llm_config,
provider=agent_state.llm_config.model_endpoint_type,
put_inner_thoughts_first=True,
)
@@ -125,7 +125,7 @@ class LettaAgent(BaseAgent):
# TODO: THIS IS INCREDIBLY UGLY
# TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
interface = AnthropicStreamingInterface(
use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=llm_client.llm_config.put_inner_thoughts_in_kwargs
use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs
)
async for chunk in interface.process(stream):
yield f"data: {chunk.model_dump_json()}\n\n"
@@ -275,45 +275,49 @@ class LettaAgent(BaseAgent):
return persisted_messages, continue_stepping
def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]:
self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor)
try:
self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor)
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
curr_system_message = in_context_messages[0]
curr_memory_str = agent_state.memory.compile()
curr_system_message_text = curr_system_message.content[0].text
if curr_memory_str in curr_system_message_text:
# NOTE: could this cause issues if a block is removed? (substring match would still work)
logger.debug(
f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
)
return in_context_messages
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
curr_system_message = in_context_messages[0]
curr_memory_str = agent_state.memory.compile()
curr_system_message_text = curr_system_message.content[0].text
if curr_memory_str in curr_system_message_text:
# NOTE: could this cause issues if a block is removed? (substring match would still work)
logger.debug(
f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
)
return in_context_messages
memory_edit_timestamp = get_utc_time()
memory_edit_timestamp = get_utc_time()
num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)
num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)
new_system_message_str = compile_system_message(
system_prompt=agent_state.system,
in_context_memory=agent_state.memory,
in_context_memory_last_edit=memory_edit_timestamp,
previous_message_count=num_messages,
archival_memory_size=num_archival_memories,
)
diff = united_diff(curr_system_message_text, new_system_message_str)
if len(diff) > 0:
logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
new_system_message = self.message_manager.update_message_by_id(
curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
new_system_message_str = compile_system_message(
system_prompt=agent_state.system,
in_context_memory=agent_state.memory,
in_context_memory_last_edit=memory_edit_timestamp,
previous_message_count=num_messages,
archival_memory_size=num_archival_memories,
)
# Skip pulling down the agent's memory again to save on a db call
return [new_system_message] + in_context_messages[1:]
diff = united_diff(curr_system_message_text, new_system_message_str)
if len(diff) > 0:
logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
else:
return in_context_messages
new_system_message = self.message_manager.update_message_by_id(
curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
)
# Skip pulling down the agent's memory again to save on a db call
return [new_system_message] + in_context_messages[1:]
else:
return in_context_messages
except:
logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})")
raise
@trace_method
async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]:

View File

@@ -39,10 +39,10 @@ def generate_langchain_tool_wrapper(
) -> tuple[str, str]:
tool_name = tool.__class__.__name__
import_statement = f"from langchain_community.tools import {tool_name}"
extra_module_imports = generate_import_code(additional_imports_module_attr_map)
extra_module_imports = _generate_import_code(additional_imports_module_attr_map)
# Safety check that user has passed in all required imports:
assert_all_classes_are_imported(tool, additional_imports_module_attr_map)
_assert_all_classes_are_imported(tool, additional_imports_module_attr_map)
tool_instantiation = f"tool = {generate_imported_tool_instantiation_call_str(tool)}"
run_call = f"return tool._run(**kwargs)"
@@ -71,7 +71,7 @@ def _assert_code_gen_compilable(code_str):
print(f"Syntax error in code: {e}")
def assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additional_imports_module_attr_map: dict[str, str]) -> None:
def _assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additional_imports_module_attr_map: dict[str, str]) -> None:
# Safety check that user has passed in all required imports:
tool_name = tool.__class__.__name__
current_class_imports = {tool_name}
@@ -193,7 +193,7 @@ def _is_base_model(obj: Any):
return isinstance(obj, BaseModel)
def generate_import_code(module_attr_map: Optional[dict]):
def _generate_import_code(module_attr_map: Optional[dict]):
if not module_attr_map:
return ""
@@ -295,7 +295,7 @@ async def _send_message_to_agent_no_stream(
return LettaResponse(messages=final_messages, usage=usage_stats)
async def async_send_message_with_retries(
async def _async_send_message_with_retries(
server: "SyncServer",
sender_agent: "Agent",
target_agent_id: str,

View File

@@ -4,7 +4,24 @@ from letta.schemas.letta_message_content import TextContent
from letta.schemas.message import Message, MessageCreate
def prepare_input_message_create(
def convert_message_creates_to_messages(
    messages: list[MessageCreate],
    agent_id: str,
    wrap_user_message: bool = True,
    wrap_system_message: bool = True,
) -> list[Message]:
    """Convert a batch of ``MessageCreate`` inputs into ``Message`` objects.

    Every input is handed, in order, to ``_convert_message_create_to_message``
    together with the shared ``agent_id`` and the two wrapping flags.

    Args:
        messages: Incoming message-creation requests to convert.
        agent_id: Agent id forwarded to every per-message conversion.
        wrap_user_message: Forwarded unchanged to the per-message converter.
        wrap_system_message: Forwarded unchanged to the per-message converter.

    Returns:
        One converted message per input, in input order.
    """
    converted = []
    for incoming in messages:
        converted.append(
            _convert_message_create_to_message(
                message=incoming,
                agent_id=agent_id,
                wrap_user_message=wrap_user_message,
                wrap_system_message=wrap_system_message,
            )
        )
    return converted
def _convert_message_create_to_message(
message: MessageCreate,
agent_id: str,
wrap_user_message: bool = True,
@@ -23,12 +40,12 @@ def prepare_input_message_create(
raise ValueError("Message content is empty or invalid")
# Apply wrapping if needed
if message.role == MessageRole.user and wrap_user_message:
if message.role not in {MessageRole.user, MessageRole.system}:
raise ValueError(f"Invalid message role: {message.role}")
elif message.role == MessageRole.user and wrap_user_message:
message_content = system.package_user_message(user_message=message_content)
elif message.role == MessageRole.system and wrap_system_message:
message_content = system.package_system_message(system_message=message_content)
elif message.role not in {MessageRole.user, MessageRole.system}:
raise ValueError(f"Invalid message role: {message.role}")
return Message(
agent_id=agent_id,

View File

@@ -35,7 +35,7 @@ from letta.schemas.letta_message import (
from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser
logger = get_logger(__name__)
@@ -56,7 +56,7 @@ class AnthropicStreamingInterface:
"""
def __init__(self, use_assistant_message: bool = False, put_inner_thoughts_in_kwarg: bool = False):
self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
self.json_parser: JSONParser = PydanticJSONParser()
self.use_assistant_message = use_assistant_message
# Premake IDs for database writes
@@ -68,7 +68,7 @@ class AnthropicStreamingInterface:
self.accumulated_inner_thoughts = []
self.tool_call_id = None
self.tool_call_name = None
self.accumulated_tool_call_args = []
self.accumulated_tool_call_args = ""
self.previous_parse = {}
# usage trackers
@@ -85,193 +85,200 @@ class AnthropicStreamingInterface:
def get_tool_call_object(self) -> ToolCall:
"""Useful for agent loop"""
return ToolCall(
id=self.tool_call_id, function=FunctionCall(arguments="".join(self.accumulated_tool_call_args), name=self.tool_call_name)
)
return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=self.accumulated_tool_call_args, name=self.tool_call_name))
def _check_inner_thoughts_complete(self, combined_args: str) -> bool:
"""
Check if inner thoughts are complete in the current tool call arguments
by looking for a closing quote after the inner_thoughts field
"""
if not self.put_inner_thoughts_in_kwarg:
# None of the things should have inner thoughts in kwargs
return True
else:
parsed = self.optimistic_json_parser.parse(combined_args)
# TODO: This will break on tools with 0 input
return len(parsed.keys()) > 1 and INNER_THOUGHTS_KWARG in parsed.keys()
try:
if not self.put_inner_thoughts_in_kwarg:
# None of the things should have inner thoughts in kwargs
return True
else:
parsed = self.json_parser.parse(combined_args)
# TODO: This will break on tools with 0 input
return len(parsed.keys()) > 1 and INNER_THOUGHTS_KWARG in parsed.keys()
except Exception as e:
logger.error("Error checking inner thoughts: %s", e)
raise
async def process(self, stream: AsyncStream[BetaRawMessageStreamEvent]) -> AsyncGenerator[LettaMessage, None]:
async with stream:
async for event in stream:
# TODO: Support BetaThinkingBlock, BetaRedactedThinkingBlock
if isinstance(event, BetaRawContentBlockStartEvent):
content = event.content_block
try:
async with stream:
async for event in stream:
# TODO: Support BetaThinkingBlock, BetaRedactedThinkingBlock
if isinstance(event, BetaRawContentBlockStartEvent):
content = event.content_block
if isinstance(content, BetaTextBlock):
self.anthropic_mode = EventMode.TEXT
# TODO: Can capture citations, etc.
elif isinstance(content, BetaToolUseBlock):
self.anthropic_mode = EventMode.TOOL_USE
self.tool_call_id = content.id
self.tool_call_name = content.name
self.inner_thoughts_complete = False
if isinstance(content, BetaTextBlock):
self.anthropic_mode = EventMode.TEXT
# TODO: Can capture citations, etc.
elif isinstance(content, BetaToolUseBlock):
self.anthropic_mode = EventMode.TOOL_USE
self.tool_call_id = content.id
self.tool_call_name = content.name
self.inner_thoughts_complete = False
if not self.use_assistant_message:
# Buffer the initial tool call message instead of yielding immediately
tool_call_msg = ToolCallMessage(
id=self.letta_tool_message_id,
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
if not self.use_assistant_message:
# Buffer the initial tool call message instead of yielding immediately
tool_call_msg = ToolCallMessage(
id=self.letta_tool_message_id,
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
date=datetime.now(timezone.utc).isoformat(),
)
self.tool_call_buffer.append(tool_call_msg)
elif isinstance(content, BetaThinkingBlock):
self.anthropic_mode = EventMode.THINKING
# TODO: Can capture signature, etc.
elif isinstance(content, BetaRedactedThinkingBlock):
self.anthropic_mode = EventMode.REDACTED_THINKING
hidden_reasoning_message = HiddenReasoningMessage(
id=self.letta_assistant_message_id,
state="redacted",
hidden_reasoning=content.data,
date=datetime.now(timezone.utc).isoformat(),
)
self.tool_call_buffer.append(tool_call_msg)
elif isinstance(content, BetaThinkingBlock):
self.anthropic_mode = EventMode.THINKING
# TODO: Can capture signature, etc.
elif isinstance(content, BetaRedactedThinkingBlock):
self.anthropic_mode = EventMode.REDACTED_THINKING
self.reasoning_messages.append(hidden_reasoning_message)
yield hidden_reasoning_message
hidden_reasoning_message = HiddenReasoningMessage(
id=self.letta_assistant_message_id,
state="redacted",
hidden_reasoning=content.data,
date=datetime.now(timezone.utc).isoformat(),
)
self.reasoning_messages.append(hidden_reasoning_message)
yield hidden_reasoning_message
elif isinstance(event, BetaRawContentBlockDeltaEvent):
delta = event.delta
elif isinstance(event, BetaRawContentBlockDeltaEvent):
delta = event.delta
if isinstance(delta, BetaTextDelta):
# Safety check
if not self.anthropic_mode == EventMode.TEXT:
raise RuntimeError(
f"Streaming integrity failed - received BetaTextDelta object while not in TEXT EventMode: {delta}"
)
if isinstance(delta, BetaTextDelta):
# Safety check
if not self.anthropic_mode == EventMode.TEXT:
raise RuntimeError(
f"Streaming integrity failed - received BetaTextDelta object while not in TEXT EventMode: {delta}"
)
# TODO: Strip out </thinking> more robustly, this is pretty hacky lol
delta.text = delta.text.replace("</thinking>", "")
self.accumulated_inner_thoughts.append(delta.text)
# TODO: Strip out </thinking> more robustly, this is pretty hacky lol
delta.text = delta.text.replace("</thinking>", "")
self.accumulated_inner_thoughts.append(delta.text)
reasoning_message = ReasoningMessage(
id=self.letta_assistant_message_id,
reasoning=self.accumulated_inner_thoughts[-1],
date=datetime.now(timezone.utc).isoformat(),
)
self.reasoning_messages.append(reasoning_message)
yield reasoning_message
elif isinstance(delta, BetaInputJSONDelta):
if not self.anthropic_mode == EventMode.TOOL_USE:
raise RuntimeError(
f"Streaming integrity failed - received BetaInputJSONDelta object while not in TOOL_USE EventMode: {delta}"
)
self.accumulated_tool_call_args.append(delta.partial_json)
combined_args = "".join(self.accumulated_tool_call_args)
current_parsed = self.optimistic_json_parser.parse(combined_args)
# Start detecting a difference in inner thoughts
previous_inner_thoughts = self.previous_parse.get(INNER_THOUGHTS_KWARG, "")
current_inner_thoughts = current_parsed.get(INNER_THOUGHTS_KWARG, "")
inner_thoughts_diff = current_inner_thoughts[len(previous_inner_thoughts) :]
if inner_thoughts_diff:
reasoning_message = ReasoningMessage(
id=self.letta_assistant_message_id,
reasoning=inner_thoughts_diff,
reasoning=self.accumulated_inner_thoughts[-1],
date=datetime.now(timezone.utc).isoformat(),
)
self.reasoning_messages.append(reasoning_message)
yield reasoning_message
# Check if inner thoughts are complete - if so, flush the buffer
if not self.inner_thoughts_complete and self._check_inner_thoughts_complete(combined_args):
self.inner_thoughts_complete = True
# Flush all buffered tool call messages
elif isinstance(delta, BetaInputJSONDelta):
if not self.anthropic_mode == EventMode.TOOL_USE:
raise RuntimeError(
f"Streaming integrity failed - received BetaInputJSONDelta object while not in TOOL_USE EventMode: {delta}"
)
self.accumulated_tool_call_args += delta.partial_json
current_parsed = self.json_parser.parse(self.accumulated_tool_call_args)
# Start detecting a difference in inner thoughts
previous_inner_thoughts = self.previous_parse.get(INNER_THOUGHTS_KWARG, "")
current_inner_thoughts = current_parsed.get(INNER_THOUGHTS_KWARG, "")
inner_thoughts_diff = current_inner_thoughts[len(previous_inner_thoughts) :]
if inner_thoughts_diff:
reasoning_message = ReasoningMessage(
id=self.letta_assistant_message_id,
reasoning=inner_thoughts_diff,
date=datetime.now(timezone.utc).isoformat(),
)
self.reasoning_messages.append(reasoning_message)
yield reasoning_message
# Check if inner thoughts are complete - if so, flush the buffer
if not self.inner_thoughts_complete and self._check_inner_thoughts_complete(self.accumulated_tool_call_args):
self.inner_thoughts_complete = True
# Flush all buffered tool call messages
for buffered_msg in self.tool_call_buffer:
yield buffered_msg
self.tool_call_buffer = []
# Start detecting special case of "send_message"
if self.tool_call_name == DEFAULT_MESSAGE_TOOL and self.use_assistant_message:
previous_send_message = self.previous_parse.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
current_send_message = current_parsed.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
send_message_diff = current_send_message[len(previous_send_message) :]
# Only stream out if it's not an empty string
if send_message_diff:
yield AssistantMessage(
id=self.letta_assistant_message_id,
content=[TextContent(text=send_message_diff)],
date=datetime.now(timezone.utc).isoformat(),
)
else:
# Otherwise, it is a normal tool call - buffer or yield based on inner thoughts status
tool_call_msg = ToolCallMessage(
id=self.letta_tool_message_id,
tool_call=ToolCallDelta(arguments=delta.partial_json),
date=datetime.now(timezone.utc).isoformat(),
)
if self.inner_thoughts_complete:
yield tool_call_msg
else:
self.tool_call_buffer.append(tool_call_msg)
# Set previous parse
self.previous_parse = current_parsed
elif isinstance(delta, BetaThinkingDelta):
# Safety check
if not self.anthropic_mode == EventMode.THINKING:
raise RuntimeError(
f"Streaming integrity failed - received BetaThinkingBlock object while not in THINKING EventMode: {delta}"
)
reasoning_message = ReasoningMessage(
id=self.letta_assistant_message_id,
source="reasoner_model",
reasoning=delta.thinking,
date=datetime.now(timezone.utc).isoformat(),
)
self.reasoning_messages.append(reasoning_message)
yield reasoning_message
elif isinstance(delta, BetaSignatureDelta):
# Safety check
if not self.anthropic_mode == EventMode.THINKING:
raise RuntimeError(
f"Streaming integrity failed - received BetaSignatureDelta object while not in THINKING EventMode: {delta}"
)
reasoning_message = ReasoningMessage(
id=self.letta_assistant_message_id,
source="reasoner_model",
reasoning="",
date=datetime.now(timezone.utc).isoformat(),
signature=delta.signature,
)
self.reasoning_messages.append(reasoning_message)
yield reasoning_message
elif isinstance(event, BetaRawMessageStartEvent):
self.message_id = event.message.id
self.input_tokens += event.message.usage.input_tokens
self.output_tokens += event.message.usage.output_tokens
elif isinstance(event, BetaRawMessageDeltaEvent):
self.output_tokens += event.usage.output_tokens
elif isinstance(event, BetaRawMessageStopEvent):
# Don't do anything here! We don't want to stop the stream.
pass
elif isinstance(event, BetaRawContentBlockStopEvent):
# If we're exiting a tool use block and there are still buffered messages,
# we should flush them now
if self.anthropic_mode == EventMode.TOOL_USE and self.tool_call_buffer:
for buffered_msg in self.tool_call_buffer:
yield buffered_msg
self.tool_call_buffer = []
# Start detecting special case of "send_message"
if self.tool_call_name == DEFAULT_MESSAGE_TOOL and self.use_assistant_message:
previous_send_message = self.previous_parse.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
current_send_message = current_parsed.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
send_message_diff = current_send_message[len(previous_send_message) :]
# Only stream out if it's not an empty string
if send_message_diff:
yield AssistantMessage(
id=self.letta_assistant_message_id,
content=[TextContent(text=send_message_diff)],
date=datetime.now(timezone.utc).isoformat(),
)
else:
# Otherwise, it is a normal tool call - buffer or yield based on inner thoughts status
tool_call_msg = ToolCallMessage(
id=self.letta_tool_message_id,
tool_call=ToolCallDelta(arguments=delta.partial_json),
date=datetime.now(timezone.utc).isoformat(),
)
if self.inner_thoughts_complete:
yield tool_call_msg
else:
self.tool_call_buffer.append(tool_call_msg)
# Set previous parse
self.previous_parse = current_parsed
elif isinstance(delta, BetaThinkingDelta):
# Safety check
if not self.anthropic_mode == EventMode.THINKING:
raise RuntimeError(
f"Streaming integrity failed - received BetaThinkingBlock object while not in THINKING EventMode: {delta}"
)
reasoning_message = ReasoningMessage(
id=self.letta_assistant_message_id,
source="reasoner_model",
reasoning=delta.thinking,
date=datetime.now(timezone.utc).isoformat(),
)
self.reasoning_messages.append(reasoning_message)
yield reasoning_message
elif isinstance(delta, BetaSignatureDelta):
# Safety check
if not self.anthropic_mode == EventMode.THINKING:
raise RuntimeError(
f"Streaming integrity failed - received BetaSignatureDelta object while not in THINKING EventMode: {delta}"
)
reasoning_message = ReasoningMessage(
id=self.letta_assistant_message_id,
source="reasoner_model",
reasoning="",
date=datetime.now(timezone.utc).isoformat(),
signature=delta.signature,
)
self.reasoning_messages.append(reasoning_message)
yield reasoning_message
elif isinstance(event, BetaRawMessageStartEvent):
self.message_id = event.message.id
self.input_tokens += event.message.usage.input_tokens
self.output_tokens += event.message.usage.output_tokens
elif isinstance(event, BetaRawMessageDeltaEvent):
self.output_tokens += event.usage.output_tokens
elif isinstance(event, BetaRawMessageStopEvent):
# Don't do anything here! We don't want to stop the stream.
pass
elif isinstance(event, BetaRawContentBlockStopEvent):
# If we're exiting a tool use block and there are still buffered messages,
# we should flush them now
if self.anthropic_mode == EventMode.TOOL_USE and self.tool_call_buffer:
for buffered_msg in self.tool_call_buffer:
yield buffered_msg
self.tool_call_buffer = []
self.anthropic_mode = None
self.anthropic_mode = None
except Exception as e:
logger.error("Error processing stream: %s", e)
raise
finally:
logger.info("AnthropicStreamingInterface: Stream processing complete.")
def get_reasoning_content(self) -> List[Union[TextContent, ReasoningContent, RedactedReasoningContent]]:
def _process_group(

View File

@@ -5,7 +5,7 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice,
from letta.constants import PRE_EXECUTION_MESSAGE_ARG
from letta.interfaces.utils import _format_sse_chunk
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
from letta.server.rest_api.json_parser import OptimisticJSONParser
class OpenAIChatCompletionsStreamingInterface:

View File

@@ -12,7 +12,7 @@ from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import LettaMessage
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
from letta.server.rest_api.json_parser import OptimisticJSONParser
from letta.streaming_interface import AgentChunkStreamingInterface
logger = get_logger(__name__)

View File

@@ -28,7 +28,7 @@ from letta.schemas.letta_message import (
from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
from letta.server.rest_api.json_parser import OptimisticJSONParser
from letta.streaming_interface import AgentChunkStreamingInterface
from letta.streaming_utils import FunctionArgumentsStreamHandler, JSONInnerThoughtsExtractor
from letta.utils import parse_json
@@ -291,7 +291,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler(json_key=assistant_message_tool_kwarg)
# @matt's changes here, adopting new optimistic json parser
self.current_function_arguments = []
self.current_function_arguments = ""
self.optimistic_json_parser = OptimisticJSONParser()
self.current_json_parse_result = {}
@@ -387,7 +387,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
def stream_start(self):
"""Initialize streaming by activating the generator and clearing any old chunks."""
self.streaming_chat_completion_mode_function_name = None
self.current_function_arguments = []
self.current_function_arguments = ""
self.current_json_parse_result = {}
if not self._active:
@@ -398,7 +398,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
def stream_end(self):
"""Clean up the stream by deactivating and clearing chunks."""
self.streaming_chat_completion_mode_function_name = None
self.current_function_arguments = []
self.current_function_arguments = ""
self.current_json_parse_result = {}
# if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
@@ -609,14 +609,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# early exit to turn into content mode
return None
if tool_call.function.arguments:
self.current_function_arguments.append(tool_call.function.arguments)
self.current_function_arguments += tool_call.function.arguments
# if we're in the middle of parsing a send_message, we'll keep processing the JSON chunks
if tool_call.function.arguments and self.streaming_chat_completion_mode_function_name == self.assistant_message_tool_name:
# Strip out any extra tokens
# In the case that we just have the prefix of something, no message yet, then we should early exit to move to the next chunk
combined_args = "".join(self.current_function_arguments)
parsed_args = self.optimistic_json_parser.parse(combined_args)
parsed_args = self.optimistic_json_parser.parse(self.current_function_arguments)
if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get(
self.assistant_message_tool_kwarg
@@ -686,7 +685,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# updates_inner_thoughts = ""
# else: # OpenAI
# updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
self.current_function_arguments.append(tool_call.function.arguments)
self.current_function_arguments += tool_call.function.arguments
updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
# If we have inner thoughts, we should output them as a chunk
@@ -805,8 +804,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# TODO: THIS IS HORRIBLE
# TODO: WE USE THE OLD JSON PARSER EARLIER (WHICH DOES NOTHING) AND NOW THE NEW JSON PARSER
# TODO: THIS IS TOTALLY WRONG AND BAD, BUT SAVING FOR A LARGER REWRITE IN THE NEAR FUTURE
combined_args = "".join(self.current_function_arguments)
parsed_args = self.optimistic_json_parser.parse(combined_args)
parsed_args = self.optimistic_json_parser.parse(self.current_function_arguments)
if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get(
self.assistant_message_tool_kwarg

View File

@@ -1,7 +1,43 @@
import json
from abc import ABC, abstractmethod
from typing import Any
from pydantic_core import from_json
from letta.log import get_logger
logger = get_logger(__name__)
class OptimisticJSONParser:
class JSONParser(ABC):
    """Abstract interface shared by the JSON parsers in this module."""

    @abstractmethod
    def parse(self, input_str: str) -> Any:
        """Parse ``input_str`` and return the decoded value.

        Subclasses must override this method; the base implementation only
        raises ``NotImplementedError``.
        """
        raise NotImplementedError
class PydanticJSONParser(JSONParser):
    """JSON parser backed by ``pydantic_core.from_json``.

    https://docs.pydantic.dev/latest/concepts/json/#json-parsing

    If `strict` is True, partial parsing is disabled (``allow_partial=False``).
    Otherwise the input is parsed with ``allow_partial="trailing-strings"``,
    so truncated JSON is accepted and an incomplete string at the end of the
    input is kept in the result.

    Compared with `OptimisticJSONParser`, this parser is more strict.
    """

    def __init__(self, strict: bool = False) -> None:
        # When True, `parse` passes allow_partial=False to `from_json`.
        self.strict = strict

    def parse(self, input_str: str) -> Any:
        """Parse ``input_str`` and return the decoded value.

        Args:
            input_str: JSON text; it may be truncated when ``strict`` is False.

        Returns:
            The value produced by ``from_json``, or an empty dict when
            ``input_str`` is empty (or otherwise falsy).

        Raises:
            ValueError: Re-raised after logging when ``from_json`` rejects the
                input.
        """
        if not input_str:
            # Nothing to parse yet; callers get an empty mapping instead of an error.
            return {}

        allow_partial = False if self.strict else "trailing-strings"
        try:
            return from_json(input_str, allow_partial=allow_partial)
        except ValueError as e:
            # Lazy %-formatting; the rendered message is the same as before.
            logger.error("Failed to parse JSON: %s", e)
            raise
class OptimisticJSONParser(JSONParser):
"""
A JSON parser that attempts to parse a given string using `json.loads`,
and if that fails, it parses as much valid JSON as possible while
@@ -13,25 +49,25 @@ class OptimisticJSONParser:
def __init__(self, strict=False):
self.strict = strict
self.parsers = {
" ": self.parse_space,
"\r": self.parse_space,
"\n": self.parse_space,
"\t": self.parse_space,
"[": self.parse_array,
"{": self.parse_object,
'"': self.parse_string,
"t": self.parse_true,
"f": self.parse_false,
"n": self.parse_null,
" ": self._parse_space,
"\r": self._parse_space,
"\n": self._parse_space,
"\t": self._parse_space,
"[": self._parse_array,
"{": self._parse_object,
'"': self._parse_string,
"t": self._parse_true,
"f": self._parse_false,
"n": self._parse_null,
}
# Register number parser for digits and signs
for char in "0123456789.-":
self.parsers[char] = self.parse_number
self.last_parse_reminding = None
self.on_extra_token = self.default_on_extra_token
self.on_extra_token = self._default_on_extra_token
def default_on_extra_token(self, text, data, reminding):
def _default_on_extra_token(self, text, data, reminding):
print(f"Parsed JSON with extra tokens: {data}, remaining: {reminding}")
def parse(self, input_str):
@@ -45,7 +81,7 @@ class OptimisticJSONParser:
try:
return json.loads(input_str)
except json.JSONDecodeError as decode_error:
data, reminding = self.parse_any(input_str, decode_error)
data, reminding = self._parse_any(input_str, decode_error)
self.last_parse_reminding = reminding
if self.on_extra_token and reminding:
self.on_extra_token(input_str, data, reminding)
@@ -53,7 +89,7 @@ class OptimisticJSONParser:
else:
return json.loads("{}")
def parse_any(self, input_str, decode_error):
def _parse_any(self, input_str, decode_error):
"""Determine which parser to use based on the first character."""
if not input_str:
raise decode_error
@@ -62,11 +98,11 @@ class OptimisticJSONParser:
raise decode_error
return parser(input_str, decode_error)
def parse_space(self, input_str, decode_error):
def _parse_space(self, input_str, decode_error):
"""Strip leading whitespace and parse again."""
return self.parse_any(input_str.strip(), decode_error)
return self._parse_any(input_str.strip(), decode_error)
def parse_array(self, input_str, decode_error):
def _parse_array(self, input_str, decode_error):
"""Parse a JSON array, returning the list and remaining string."""
# Skip the '['
input_str = input_str[1:]
@@ -77,7 +113,7 @@ class OptimisticJSONParser:
# Skip the ']'
input_str = input_str[1:]
break
value, input_str = self.parse_any(input_str, decode_error)
value, input_str = self._parse_any(input_str, decode_error)
array_values.append(value)
input_str = input_str.strip()
if input_str.startswith(","):
@@ -85,7 +121,7 @@ class OptimisticJSONParser:
input_str = input_str[1:].strip()
return array_values, input_str
def parse_object(self, input_str, decode_error):
def _parse_object(self, input_str, decode_error):
"""Parse a JSON object, returning the dict and remaining string."""
# Skip the '{'
input_str = input_str[1:]
@@ -96,7 +132,7 @@ class OptimisticJSONParser:
# Skip the '}'
input_str = input_str[1:]
break
key, input_str = self.parse_any(input_str, decode_error)
key, input_str = self._parse_any(input_str, decode_error)
input_str = input_str.strip()
if not input_str or input_str[0] == "}":
@@ -113,7 +149,7 @@ class OptimisticJSONParser:
input_str = input_str[1:]
break
value, input_str = self.parse_any(input_str, decode_error)
value, input_str = self._parse_any(input_str, decode_error)
obj[key] = value
input_str = input_str.strip()
if input_str.startswith(","):
@@ -121,7 +157,7 @@ class OptimisticJSONParser:
input_str = input_str[1:].strip()
return obj, input_str
def parse_string(self, input_str, decode_error):
def _parse_string(self, input_str, decode_error):
"""Parse a JSON string, respecting escaped quotes if present."""
end = input_str.find('"', 1)
while end != -1 and input_str[end - 1] == "\\":
@@ -166,19 +202,19 @@ class OptimisticJSONParser:
return num, remainder
def parse_true(self, input_str, decode_error):
def _parse_true(self, input_str, decode_error):
# Diff note: the first `def` line appears to be the removed signature and the second its
# replacement; the change shown here is only the rename to a `_`-prefixed name.
"""Parse a 'true' value."""
# Only the first character is inspected ("t" or "T"); the remaining characters are not
# compared with "true", so a truncated literal such as "tr" is still accepted.
if input_str.startswith(("t", "T")):
# Returns (True, rest): four characters (the length of "true") are skipped; slicing past the
# end of a shorter string yields "".
return True, input_str[4:]
# Any other leading character re-raises the error object supplied by the caller.
raise decode_error
def parse_false(self, input_str, decode_error):
def _parse_false(self, input_str, decode_error):
# Diff note: the first `def` line appears to be the removed signature and the second its
# replacement; the change shown here is only the rename to a `_`-prefixed name.
"""Parse a 'false' value."""
# Only the first character is inspected ("f" or "F"); the remaining characters are not
# compared with "false", so a truncated literal such as "fal" is still accepted.
if input_str.startswith(("f", "F")):
# Returns (False, rest): five characters (the length of "false") are skipped; slicing past
# the end of a shorter string yields "".
return False, input_str[5:]
# Any other leading character re-raises the error object supplied by the caller.
raise decode_error
def parse_null(self, input_str, decode_error):
def _parse_null(self, input_str, decode_error):
"""Parse a 'null' value."""
if input_str.startswith("n"):
return None, input_str[4:]

View File

@@ -680,7 +680,7 @@ async def send_message_streaming(
server: SyncServer = Depends(get_letta_server),
request: LettaStreamingRequest = Body(...),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
) -> StreamingResponse | LettaResponse:
"""
Process a user message and return the agent's response.
This endpoint accepts a message from a user and processes it through the agent.

View File

@@ -16,6 +16,7 @@ from pydantic import BaseModel
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE
from letta.errors import ContextWindowExceededError, RateLimitExceededError
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.message_helper import convert_message_creates_to_messages
from letta.log import get_logger
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
@@ -143,27 +144,15 @@ def log_error_to_sentry(e):
def create_input_messages(input_messages: List[MessageCreate], agent_id: str, actor: User) -> List[Message]:
# Diff note: this hunk appears to replace the whole function body, with old and new lines interleaved.
# - Removed: the first closing triple quote and everything from `new_messages = []` through
#   `return new_messages`. That version built one Message per input by hand, copying role, content,
#   name, otid and sender_id, generating a fresh "message-<uuid4>" id, and stamping created_at.
# - Added: the TODO text (which extends the docstring), the second closing triple quote, and
#   everything from `messages = ...` through `return messages`. That version delegates to
#   convert_message_creates_to_messages with wrap_user_message=False and wrap_system_message=False,
#   then sets organization_id from `actor` on each returned message.
# Both versions return the list of messages they built.
"""
Converts a user input message into the internal structured format.
"""
new_messages = []
for input_message in input_messages:
# Construct the Message object
new_message = Message(
id=f"message-{uuid.uuid4()}",
role=input_message.role,
content=input_message.content,
name=input_message.name,
otid=input_message.otid,
sender_id=input_message.sender_id,
organization_id=actor.organization_id,
agent_id=agent_id,
model=None,
tool_calls=None,
tool_call_id=None,
created_at=get_utc_time(),
)
new_messages.append(new_message)
return new_messages
TODO (cliandy): this effectively duplicates the functionality of `convert_message_creates_to_messages`,
we should unify this when it's clear what message attributes we need.
"""
messages = convert_message_creates_to_messages(input_messages, agent_id, wrap_user_message=False, wrap_system_message=False)
for message in messages:
message.organization_id = actor.organization_id
return messages
def create_letta_messages_from_llm_response(

View File

@@ -3,7 +3,7 @@ from unittest.mock import patch
import pytest
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
from letta.server.rest_api.json_parser import OptimisticJSONParser
@pytest.fixture