From 9458f40d05007a2dde01f35ceea1e5671f70b6fd Mon Sep 17 00:00:00 2001 From: cthomas Date: Fri, 4 Apr 2025 08:43:01 -0700 Subject: [PATCH] feat: add otid field for message idempotency (#1556) --- letta/agent.py | 25 +++++++---- letta/functions/interface.py | 4 +- letta/groups/background_multi_agent.py | 1 + letta/groups/dynamic_multi_agent.py | 3 ++ letta/groups/round_robin_multi_agent.py | 2 + letta/groups/supervisor_multi_agent.py | 1 + letta/helpers/message_helper.py | 1 + letta/interface.py | 17 +++++--- letta/llm_api/anthropic.py | 1 + letta/llm_api/openai.py | 1 + letta/schemas/letta_message.py | 2 + letta/schemas/message.py | 41 +++++++++++++++++++ .../rest_api/chat_completions_interface.py | 4 +- letta/server/rest_api/interface.py | 40 +++++++++++++++--- letta/server/server.py | 1 + letta/streaming_interface.py | 23 +++++++---- tests/test_streaming.py | 18 +++++++- 17 files changed, 152 insertions(+), 33 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index 978197e7..1c227af0 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -130,6 +130,7 @@ class Agent(BaseAgent): # Different interfaces can handle events differently # e.g., print in CLI vs send a discord message with a discord bot self.interface = interface + self.chunk_index = 0 # Create the persistence manager object based on the AgentState info self.message_manager = MessageManager() @@ -246,9 +247,11 @@ class Agent(BaseAgent): group_id=group_id, ) messages.append(new_message) - self.interface.function_message(f"Error: {error_msg}", msg_obj=new_message) + self.interface.function_message(f"Error: {error_msg}", msg_obj=new_message, chunk_index=self.chunk_index) + self.chunk_index += 1 if include_function_failed_message: - self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=new_message) + self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=new_message, chunk_index=self.chunk_index) + self.chunk_index += 1 # Return updated messages return messages @@ -430,7 +433,8 @@ class Agent(BaseAgent): nonnull_content = False if response_message.content or response_message.reasoning_content or response_message.redacted_reasoning_content: # The content if then internal monologue, not chat - self.interface.internal_monologue(response_message.content, msg_obj=messages[-1]) + self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=self.chunk_index) + self.chunk_index += 1 # Flag to avoid printing a duplicate if inner thoughts get popped from the function call nonnull_content = True @@ -479,7 +483,8 @@ class Agent(BaseAgent): response_message.content = function_args.pop("inner_thoughts") # The content if then internal monologue, not chat if response_message.content and not nonnull_content: - self.interface.internal_monologue(response_message.content, msg_obj=messages[-1]) + self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=self.chunk_index) + self.chunk_index += 1 # (Still parsing function args) # Handle requests for immediate heartbeat @@ -501,7 +506,8 @@ class Agent(BaseAgent): # Failure case 3: function failed during execution # NOTE: the msg_obj associated with the "Running " message is the prior assistant message, not the function/tool role message # this is because the function/tool role message is only created once the function/tool has executed/returned - self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1]) + self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1], chunk_index=self.chunk_index) + self.chunk_index += 1 try: # handle tool execution (sandbox) and state updates log_telemetry( @@ -634,8 +640,10 @@ class Agent(BaseAgent): group_id=group_id, ) ) # extend conversation with function response - self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1]) - self.interface.function_message(f"Success: {function_response_string}", msg_obj=messages[-1]) + self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1], chunk_index=self.chunk_index) + self.chunk_index += 1 + self.interface.function_message(f"Success: {function_response_string}", msg_obj=messages[-1], chunk_index=self.chunk_index) + self.chunk_index += 1 self.last_function_response = function_response else: @@ -651,7 +659,8 @@ class Agent(BaseAgent): group_id=group_id, ) ) # extend conversation with assistant's reply - self.interface.internal_monologue(response_message.content, msg_obj=messages[-1]) + self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=self.chunk_index) + self.chunk_index += 1 heartbeat_request = False function_failed = False diff --git a/letta/functions/interface.py b/letta/functions/interface.py index 82bf229e..2e284de3 100644 --- a/letta/functions/interface.py +++ b/letta/functions/interface.py @@ -18,13 +18,13 @@ class MultiAgentMessagingInterface(AgentInterface): self._captured_messages: List[AssistantMessage] = [] self.metadata = {} - def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None): + def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None): """Ignore internal monologue.""" def assistant_message(self, msg: str, msg_obj: Optional[Message] = None): """Ignore normal assistant messages (only capturing send_message calls).""" - def function_message(self, msg: str, msg_obj: Optional[Message] = None): + def function_message(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None): """ Called whenever the agent logs a function call. We'll inspect msg_obj.tool_calls: - If tool_calls include a function named 'send_message', parse its arguments diff --git a/letta/groups/background_multi_agent.py b/letta/groups/background_multi_agent.py index 13be7f70..64cba869 100644 --- a/letta/groups/background_multi_agent.py +++ b/letta/groups/background_multi_agent.py @@ -192,6 +192,7 @@ class BackgroundMultiAgent(Agent): tool_calls=None, tool_call_id=None, group_id=self.group_id, + otid=message.otid, ) for message in messages ] diff --git a/letta/groups/dynamic_multi_agent.py b/letta/groups/dynamic_multi_agent.py index c807efa7..9f0973ea 100644 --- a/letta/groups/dynamic_multi_agent.py +++ b/letta/groups/dynamic_multi_agent.py @@ -99,6 +99,7 @@ class DynamicMultiAgent(Agent): tool_calls=None, tool_call_id=None, group_id=self.group_id, + otid=message.otid, ) ) @@ -125,6 +126,7 @@ class DynamicMultiAgent(Agent): role="system", content=message.content, name=participant_agent.agent_state.name, + otid=message.otid, ) for message in assistant_messages ] @@ -271,4 +273,5 @@ class DynamicMultiAgent(Agent): tool_calls=None, tool_call_id=None, group_id=self.group_id, + otid=Message.generate_otid(), ) diff --git a/letta/groups/round_robin_multi_agent.py b/letta/groups/round_robin_multi_agent.py index 9bb62146..4a9bcaaa 100644 --- a/letta/groups/round_robin_multi_agent.py +++ b/letta/groups/round_robin_multi_agent.py @@ -69,6 +69,7 @@ class RoundRobinMultiAgent(Agent): tool_calls=None, tool_call_id=None, group_id=self.group_id, + otid=message.otid, ) ) @@ -92,6 +93,7 @@ class RoundRobinMultiAgent(Agent): role="system", content=message.content, name=message.name, + otid=message.otid, ) for message in assistant_messages ] diff --git a/letta/groups/supervisor_multi_agent.py b/letta/groups/supervisor_multi_agent.py index 98c3d7ab..bdd8f4f9 100644 --- a/letta/groups/supervisor_multi_agent.py +++ b/letta/groups/supervisor_multi_agent.py @@ -89,6 +89,7 @@ class SupervisorMultiAgent(Agent): tool_calls=None, tool_call_id=None, group_id=self.group_id, + otid=message.otid, ) for message in messages ] diff --git a/letta/helpers/message_helper.py b/letta/helpers/message_helper.py index 6f9fe5ea..5f040ced 100644 --- a/letta/helpers/message_helper.py +++ b/letta/helpers/message_helper.py @@ -38,4 +38,5 @@ def prepare_input_message_create( model=None, # assigned later? tool_calls=None, # irrelevant tool_call_id=None, + otid=message.otid, ) diff --git a/letta/interface.py b/letta/interface.py index 9e0acbd2..281274de 100644 --- a/letta/interface.py +++ b/letta/interface.py @@ -30,7 +30,7 @@ class AgentInterface(ABC): raise NotImplementedError @abstractmethod - def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None): + def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None): """Letta generates some internal monologue""" raise NotImplementedError @@ -40,7 +40,7 @@ class AgentInterface(ABC): raise NotImplementedError @abstractmethod - def function_message(self, msg: str, msg_obj: Optional[Message] = None): + def function_message(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None): """Letta calls a function""" raise NotImplementedError @@ -79,7 +79,7 @@ class CLIInterface(AgentInterface): print(fstr.format(msg=msg)) @staticmethod - def internal_monologue(msg: str, msg_obj: Optional[Message] = None): + def internal_monologue(msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None): # ANSI escape code for italic is '\x1B[3m' fstr = f"\x1B[3m{Fore.LIGHTBLACK_EX}{INNER_THOUGHTS_CLI_SYMBOL} {{msg}}{Style.RESET_ALL}" if STRIP_UI: @@ -108,7 +108,14 @@ class CLIInterface(AgentInterface): print(fstr.format(msg=msg)) @staticmethod - def user_message(msg: str, msg_obj: Optional[Message] = None, raw: bool = False, dump: bool = False, debug: bool = DEBUG): + def user_message( + msg: str, + msg_obj: Optional[Message] = None, + raw: bool = False, + dump: bool = False, + debug: bool = DEBUG, + chunk_index: Optional[int] = None, + ): def print_user_message(icon, msg, printf=print): if STRIP_UI: printf(f"{icon} {msg}") @@ -154,7 +161,7 @@ class CLIInterface(AgentInterface): printd_user_message("🧑", msg_json) @staticmethod - def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG): + def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG, chunk_index: Optional[int] = None): def print_function_message(icon, msg, color=Fore.RED, printf=print): if STRIP_UI: printf(f"⚡{icon} [function] {msg}") diff --git a/letta/llm_api/anthropic.py b/letta/llm_api/anthropic.py index 953ae530..1fa1e2cf 100644 --- a/letta/llm_api/anthropic.py +++ b/letta/llm_api/anthropic.py @@ -953,6 +953,7 @@ def anthropic_chat_completions_process_stream( # TODO handle emitting redacted reasoning content (e.g. as concat?) expect_reasoning_content=extended_thinking, name=name, + chunk_index=chunk_idx, ) elif isinstance(stream_interface, AgentRefreshStreamingInterface): stream_interface.process_refresh(chat_completion_response) diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index a99ffb78..8d4127ae 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -274,6 +274,7 @@ def openai_chat_completions_process_stream( message_date=chat_completion_response.created if create_message_datetime else chat_completion_chunk.created, expect_reasoning_content=expect_reasoning_content, name=name, + chunk_index=chunk_idx, ) elif isinstance(stream_interface, AgentRefreshStreamingInterface): stream_interface.process_refresh(chat_completion_response) diff --git a/letta/schemas/letta_message.py b/letta/schemas/letta_message.py index ffaaa8ea..ec58d8c6 100644 --- a/letta/schemas/letta_message.py +++ b/letta/schemas/letta_message.py @@ -26,11 +26,13 @@ class LettaMessage(BaseModel): id (str): The ID of the message date (datetime): The date the message was created in ISO format name (Optional[str]): The name of the sender of the message + otid (Optional[str]): The offline threading id associated with this message """ id: str date: datetime name: Optional[str] = None + otid: Optional[str] = None @field_serializer("date") def serialize_datetime(self, dt: datetime, _info): diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 4881ab0e..2e6c6372 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -2,6 +2,7 @@ from __future__ import annotations import copy import json +import uuid import warnings from collections import OrderedDict from datetime import datetime, timezone @@ -78,6 +79,7 @@ class MessageCreate(BaseModel): json_schema_extra=get_letta_message_content_union_str_json_schema(), ) name: Optional[str] = Field(None, description="The name of the participant.") + otid: Optional[str] = Field(None, description="The offline threading id associated with this message") def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]: data = super().model_dump(**kwargs) @@ -168,12 +170,17 @@ class Message(BaseMessage): json_message["created_at"] = self.created_at.isoformat() return json_message + @staticmethod + def generate_otid(): + return str(uuid.uuid4()) + @staticmethod def to_letta_messages_from_list( messages: List[Message], use_assistant_message: bool = True, assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL, assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG, + reverse: bool = True, ) -> List[LettaMessage]: if use_assistant_message: message_ids_to_remove = [] @@ -203,6 +210,7 @@ class Message(BaseMessage): use_assistant_message=use_assistant_message, assistant_message_tool_name=assistant_message_tool_name, assistant_message_tool_kwarg=assistant_message_tool_kwarg, + reverse=reverse, ) ] @@ -211,6 +219,7 @@ class Message(BaseMessage): use_assistant_message: bool = False, assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL, assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG, + reverse: bool = True, ) -> List[LettaMessage]: """Convert message object (in DB format) to the style used by the original Letta API""" messages = [] @@ -221,18 +230,21 @@ class Message(BaseMessage): if self.content: # Check for ReACT-style COT inside of TextContent if len(self.content) == 1 and isinstance(self.content[0], TextContent): + otid = Message.generate_otid_from_id(self.id, len(messages)) messages.append( ReasoningMessage( id=self.id, date=self.created_at, reasoning=self.content[0].text, name=self.name, + otid=otid, ) ) # Otherwise, we may have a list of multiple types else: # TODO we can probably collapse these two cases into a single loop for content_part in self.content: + otid = Message.generate_otid_from_id(self.id, len(messages)) if isinstance(content_part, TextContent): # COT messages.append( @@ -241,6 +253,7 @@ class Message(BaseMessage): date=self.created_at, reasoning=content_part.text, name=self.name, + otid=otid, ) ) elif isinstance(content_part, ReasoningContent): @@ -253,6 +266,7 @@ class Message(BaseMessage): source="reasoner_model", # TODO do we want to tag like this? signature=content_part.signature, name=self.name, + otid=otid, ) ) elif isinstance(content_part, RedactedReasoningContent): @@ -264,6 +278,7 @@ class Message(BaseMessage): state="redacted", hidden_reasoning=content_part.data, name=self.name, + otid=otid, ) ) else: @@ -272,6 +287,7 @@ class Message(BaseMessage): if self.tool_calls is not None: # This is type FunctionCall for tool_call in self.tool_calls: + otid = Message.generate_otid_from_id(self.id, len(messages)) # If we're supporting using assistant message, # then we want to treat certain function calls as a special case if use_assistant_message and tool_call.function.name == assistant_message_tool_name: @@ -287,6 +303,7 @@ class Message(BaseMessage): date=self.created_at, content=message_string, name=self.name, + otid=otid, ) ) else: @@ -300,6 +317,7 @@ class Message(BaseMessage): tool_call_id=tool_call.id, ), name=self.name, + otid=otid, ) ) elif self.role == MessageRole.tool: @@ -341,6 +359,7 @@ class Message(BaseMessage): stdout=self.tool_returns[0].stdout if self.tool_returns else None, stderr=self.tool_returns[0].stderr if self.tool_returns else None, name=self.name, + otid=self.id.replace("message-", ""), ) ) elif self.role == MessageRole.user: @@ -357,6 +376,7 @@ class Message(BaseMessage): date=self.created_at, content=message_str or text_content, name=self.name, + otid=self.otid, ) ) elif self.role == MessageRole.system: @@ -372,11 +392,15 @@ class Message(BaseMessage): date=self.created_at, content=text_content, name=self.name, + otid=self.otid, ) ) else: raise ValueError(self.role) + if reverse: + messages.reverse() + return messages @staticmethod @@ -991,6 +1015,23 @@ class Message(BaseMessage): return cohere_message + @staticmethod + def generate_otid_from_id(message_id: str, index: int) -> str: + """ + Convert message id to bits and change the list bit to the index + """ + if not 0 <= index < 128: + raise ValueError("Index must be between 0 and 127") + + message_uuid = message_id.replace("message-", "") + uuid_int = int(message_uuid.replace("-", ""), 16) + + # Clear last 7 bits and set them to index; supports up to 128 unique indices + uuid_int = (uuid_int & ~0x7F) | (index & 0x7F) + + hex_str = f"{uuid_int:032x}" + return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:]}" + class ToolReturn(BaseModel): status: Literal["success", "error"] = Field(..., description="The status of the tool call") diff --git a/letta/server/rest_api/chat_completions_interface.py b/letta/server/rest_api/chat_completions_interface.py index 5db367c6..4b5730b3 100644 --- a/letta/server/rest_api/chat_completions_interface.py +++ b/letta/server/rest_api/chat_completions_interface.py @@ -172,7 +172,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface): """ return - def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None) -> None: + def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None) -> None: """ Handle LLM reasoning or internal monologue. Example usage: if you want to capture chain-of-thought for debugging in a non-streaming scenario. @@ -186,7 +186,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface): """ return - def function_message(self, msg: str, msg_obj: Optional[Message] = None) -> None: + def function_message(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None) -> None: """ Handle function-related log messages, typically of the form: It's a no-op by default. diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py index 405eb476..a48bc57c 100644 --- a/letta/server/rest_api/interface.py +++ b/letta/server/rest_api/interface.py @@ -165,7 +165,7 @@ class QueuingInterface(AgentInterface): print(vars(msg_obj)) print(msg_obj.created_at.isoformat()) - def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None) -> None: + def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None) -> None: """Handle the agent's internal monologue""" assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata" if self.debug: @@ -209,7 +209,9 @@ class QueuingInterface(AgentInterface): self._queue_push(message_api=new_message, message_obj=msg_obj) - def function_message(self, msg: str, msg_obj: Optional[Message] = None, include_ran_messages: bool = False) -> None: + def function_message( + self, msg: str, msg_obj: Optional[Message] = None, include_ran_messages: bool = False, chunk_index: Optional[int] = None + ) -> None: """Handle the agent calling a function""" # TODO handle 'function' messages that indicate the start of a function call assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata" @@ -466,6 +468,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # and `content` needs to be handled outside the interface expect_reasoning_content: bool = False, name: Optional[str] = None, + chunk_index: int = 0, ) -> Optional[Union[ReasoningMessage, ToolCallMessage, AssistantMessage]]: """ Example data from non-streaming response looks like: @@ -478,6 +481,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): """ choice = chunk.choices[0] message_delta = choice.delta + otid = Message.generate_otid_from_id(message_id, chunk_index) if ( message_delta.content is None @@ -499,6 +503,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): signature=message_delta.reasoning_content_signature, source="reasoner_model" if message_delta.reasoning_content_signature else "non_reasoner_model", name=name, + otid=otid, ) elif expect_reasoning_content and message_delta.redacted_reasoning_content is not None: processed_chunk = HiddenReasoningMessage( @@ -507,6 +512,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): hidden_reasoning=message_delta.redacted_reasoning_content, state="redacted", name=name, + otid=otid, ) elif expect_reasoning_content and message_delta.content is not None: # "ignore" content if we expect reasoning content @@ -534,6 +540,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): tool_call_id=None, ), name=name, + otid=otid, ) except json.JSONDecodeError as e: @@ -564,6 +571,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): date=message_date, reasoning=message_delta.content, name=name, + otid=otid, ) # tool calls @@ -612,7 +620,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # TODO: Assumes consistent state and that prev_content is subset of new_content diff = new_content.replace(prev_content, "", 1) self.current_json_parse_result = parsed_args - processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff, name=name) + processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff, name=name, otid=otid) else: return None @@ -645,6 +653,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): tool_call_id=tool_call_delta.get("id"), ), name=name, + otid=otid, ) elif self.inner_thoughts_in_kwargs and tool_call.function: @@ -681,6 +690,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): date=message_date, reasoning=updates_inner_thoughts, name=name, + otid=otid, ) # Additionally inner thoughts may stream back with a chunk of main JSON # In that case, since we can only return a chunk at a time, we should buffer it @@ -717,6 +727,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): tool_call_id=self.function_id_buffer, ), name=name, + otid=otid, ) # Record what the last function name we flushed was @@ -774,6 +785,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): date=message_date, content=combined_chunk, name=name, + otid=otid, ) # Store the ID of the tool call so allow skipping the corresponding response if self.function_id_buffer: @@ -798,7 +810,9 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # TODO: Assumes consistent state and that prev_content is subset of new_content diff = new_content.replace(prev_content, "", 1) self.current_json_parse_result = parsed_args - processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff, name=name) + processed_chunk = AssistantMessage( + id=message_id, date=message_date, content=diff, name=name, otid=otid + ) else: return None @@ -823,6 +837,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): tool_call_id=self.function_id_buffer, ), name=name, + otid=otid, ) # clear buffer self.function_args_buffer = None @@ -838,6 +853,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): tool_call_id=self.function_id_buffer, ), name=name, + otid=otid, ) self.function_id_buffer = None @@ -967,6 +983,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): tool_call_id=tool_call_delta.get("id"), ), name=name, + otid=otid, ) elif choice.finish_reason is not None: @@ -1048,6 +1065,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): message_date: datetime, expect_reasoning_content: bool = False, name: Optional[str] = None, + chunk_index: int = 0, ): """Process a streaming chunk from an OpenAI-compatible server. @@ -1074,6 +1092,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): message_date=message_date, expect_reasoning_content=expect_reasoning_content, name=name, + chunk_index=chunk_index, ) if processed_chunk is None: @@ -1085,7 +1104,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): """Letta receives a user message""" return - def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None): + def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None): """Letta generates some internal monologue""" if not self.streaming_mode: @@ -1102,6 +1121,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): date=msg_obj.created_at, reasoning=msg, name=msg_obj.name, + otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None, ) self._push_to_buffer(processed_chunk) @@ -1113,6 +1133,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): date=msg_obj.created_at, reasoning=content.text, name=msg_obj.name, + otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None, ) elif isinstance(content, ReasoningContent): processed_chunk = ReasoningMessage( @@ -1122,6 +1143,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): reasoning=content.reasoning, signature=content.signature, name=msg_obj.name, + otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None, ) elif isinstance(content, RedactedReasoningContent): processed_chunk = HiddenReasoningMessage( @@ -1130,6 +1152,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): state="redacted", hidden_reasoning=content.data, name=msg_obj.name, + otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None, ) self._push_to_buffer(processed_chunk) @@ -1142,7 +1165,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # NOTE: this is a no-op, we handle this special case in function_message instead return - def function_message(self, msg: str, msg_obj: Optional[Message] = None): + def function_message(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None): """Letta calls a function""" # TODO handle 'function' messages that indicate the start of a function call @@ -1191,6 +1214,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): date=msg_obj.created_at, content=func_args["message"], name=msg_obj.name, + otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None, ) self._push_to_buffer(processed_chunk) except Exception as e: @@ -1214,6 +1238,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): date=msg_obj.created_at, content=func_args[self.assistant_message_tool_kwarg], name=msg_obj.name, + otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None, ) # Store the ID of the tool call so allow skipping the corresponding response self.prev_assistant_message_id = function_call.id @@ -1227,6 +1252,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): tool_call_id=function_call.id, ), name=msg_obj.name, + otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None, ) # processed_chunk = { @@ -1267,6 +1293,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): stdout=msg_obj.tool_returns[0].stdout if msg_obj.tool_returns else None, stderr=msg_obj.tool_returns[0].stderr if msg_obj.tool_returns else None, name=msg_obj.name, + otid=Message.generate_otid_from_id(msg_obj.id, chunk_index), ) elif msg.startswith("Error: "): @@ -1282,6 +1309,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): stdout=msg_obj.tool_returns[0].stdout if msg_obj.tool_returns else None, stderr=msg_obj.tool_returns[0].stderr if msg_obj.tool_returns else None, name=msg_obj.name, + otid=Message.generate_otid_from_id(msg_obj.id, chunk_index), ) else: diff --git a/letta/server/server.py b/letta/server/server.py index 0c11e4c0..83e6cd10 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -892,6 +892,7 @@ class SyncServer(Server): use_assistant_message=use_assistant_message, assistant_message_tool_name=assistant_message_tool_name, assistant_message_tool_kwarg=assistant_message_tool_kwarg, + reverse=reverse, ) if reverse: diff --git a/letta/streaming_interface.py b/letta/streaming_interface.py index 3d007cad..005b104e 100644 --- a/letta/streaming_interface.py +++ b/letta/streaming_interface.py @@ -33,7 +33,7 @@ class AgentChunkStreamingInterface(ABC): raise NotImplementedError @abstractmethod - def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None): + def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None): """Letta generates some internal monologue""" raise NotImplementedError @@ -43,13 +43,18 @@ class AgentChunkStreamingInterface(ABC): raise NotImplementedError @abstractmethod - def function_message(self, msg: str, msg_obj: Optional[Message] = None): + def function_message(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None): """Letta calls a function""" raise NotImplementedError @abstractmethod def process_chunk( - self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime, expect_reasoning_content: bool = False + self, + chunk: ChatCompletionChunkResponse, + message_id: str, + message_date: datetime, + expect_reasoning_content: bool = False, + chunk_index: int = 0, ): """Process a streaming chunk from an OpenAI-compatible server""" raise NotImplementedError @@ -166,7 +171,7 @@ class StreamingCLIInterface(AgentChunkStreamingInterface): StreamingCLIInterface.nonstreaming_interface(msg) @staticmethod - def internal_monologue(msg: str, msg_obj: Optional[Message] = None): + def internal_monologue(msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None): StreamingCLIInterface.nonstreaming_interface(msg, msg_obj) @staticmethod @@ -186,7 +191,7 @@ class StreamingCLIInterface(AgentChunkStreamingInterface): StreamingCLIInterface.nonstreaming_interface(msg, msg_obj) @staticmethod - def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG): + def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG, chunk_index: Optional[int] = None): StreamingCLIInterface.nonstreaming_interface(msg, msg_obj) @staticmethod @@ -218,7 +223,7 @@ class AgentRefreshStreamingInterface(ABC): raise NotImplementedError @abstractmethod - def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None): + def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None): """Letta generates some internal monologue""" raise NotImplementedError @@ -228,7 +233,7 @@ class AgentRefreshStreamingInterface(ABC): raise NotImplementedError @abstractmethod - def function_message(self, msg: str, msg_obj: Optional[Message] = None): + def function_message(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None): """Letta calls a function""" raise NotImplementedError @@ -355,7 +360,7 @@ class StreamingRefreshCLIInterface(AgentRefreshStreamingInterface): def warning_message(msg: str): StreamingCLIInterface.nonstreaming_interface.warning_message(msg) - def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None): + def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None): if self.disable_inner_mono_call: return StreamingCLIInterface.nonstreaming_interface.internal_monologue(msg, msg_obj) @@ -378,7 +383,7 @@ class StreamingRefreshCLIInterface(AgentRefreshStreamingInterface): StreamingCLIInterface.nonstreaming_interface.user_message(msg, msg_obj) @staticmethod - def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG): + def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG, chunk_index: Optional[int] = None): StreamingCLIInterface.nonstreaming_interface.function_message(msg, msg_obj) @staticmethod diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 18432250..016bac3d 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -6,6 +6,8 @@ import pytest from dotenv import load_dotenv from letta_client import AgentState, Letta, LlmConfig, MessageCreate +from letta.schemas.message import Message + def run_server(): load_dotenv() @@ -73,9 +75,16 @@ def test_streaming_send_message( client.agents.modify(agent_id=agent.id, llm_config=config) # Send streaming message + user_message_otid = Message.generate_otid() response = client.agents.messages.create_stream( agent_id=agent.id, - messages=[MessageCreate(role="user", content="This is a test. Repeat after me: 'banana'")], + messages=[ + MessageCreate( + role="user", + content="This is a test. Repeat after me: 'banana'", + otid=user_message_otid, + ), + ], stream_tokens=stream_tokens, ) @@ -84,6 +93,8 @@ def test_streaming_send_message( inner_thoughts_count = 0 send_message_ran = False done = False + last_message_id = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id + letta_message_otids = [user_message_otid] assert response, "Sending message failed" for chunk in response: @@ -104,6 +115,8 @@ def test_streaming_send_message( assert chunk.prompt_tokens > 1000 assert chunk.total_tokens > 1000 done = True + else: + letta_message_otids.append(chunk.otid) print(chunk) # If stream tokens, we expect at least one inner thought @@ -111,3 +124,6 @@ def test_streaming_send_message( assert inner_thoughts_exist, "No inner thoughts found" assert send_message_ran, "send_message function call not found" assert done, "Message stream not done" + + messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id) + assert [message.otid for message in messages] == letta_message_otids