feat: add otid field for message idempotency (#1556)

This commit is contained in:
cthomas
2025-04-04 08:43:01 -07:00
committed by GitHub
parent 3ba79db859
commit 9458f40d05
17 changed files with 152 additions and 33 deletions

View File

@@ -130,6 +130,7 @@ class Agent(BaseAgent):
# Different interfaces can handle events differently
# e.g., print in CLI vs send a discord message with a discord bot
self.interface = interface
self.chunk_index = 0
# Create the persistence manager object based on the AgentState info
self.message_manager = MessageManager()
@@ -246,9 +247,11 @@ class Agent(BaseAgent):
group_id=group_id,
)
messages.append(new_message)
self.interface.function_message(f"Error: {error_msg}", msg_obj=new_message)
self.interface.function_message(f"Error: {error_msg}", msg_obj=new_message, chunk_index=self.chunk_index)
self.chunk_index += 1
if include_function_failed_message:
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=new_message)
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=new_message, chunk_index=self.chunk_index)
self.chunk_index += 1
# Return updated messages
return messages
@@ -430,7 +433,8 @@ class Agent(BaseAgent):
nonnull_content = False
if response_message.content or response_message.reasoning_content or response_message.redacted_reasoning_content:
# The content is then internal monologue, not chat
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1])
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=self.chunk_index)
self.chunk_index += 1
# Flag to avoid printing a duplicate if inner thoughts get popped from the function call
nonnull_content = True
@@ -479,7 +483,8 @@ class Agent(BaseAgent):
response_message.content = function_args.pop("inner_thoughts")
# The content is then internal monologue, not chat
if response_message.content and not nonnull_content:
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1])
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=self.chunk_index)
self.chunk_index += 1
# (Still parsing function args)
# Handle requests for immediate heartbeat
@@ -501,7 +506,8 @@ class Agent(BaseAgent):
# Failure case 3: function failed during execution
# NOTE: the msg_obj associated with the "Running " message is the prior assistant message, not the function/tool role message
# this is because the function/tool role message is only created once the function/tool has executed/returned
self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1])
self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1], chunk_index=self.chunk_index)
self.chunk_index += 1
try:
# handle tool execution (sandbox) and state updates
log_telemetry(
@@ -634,8 +640,10 @@ class Agent(BaseAgent):
group_id=group_id,
)
) # extend conversation with function response
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1])
self.interface.function_message(f"Success: {function_response_string}", msg_obj=messages[-1])
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1], chunk_index=self.chunk_index)
self.chunk_index += 1
self.interface.function_message(f"Success: {function_response_string}", msg_obj=messages[-1], chunk_index=self.chunk_index)
self.chunk_index += 1
self.last_function_response = function_response
else:
@@ -651,7 +659,8 @@ class Agent(BaseAgent):
group_id=group_id,
)
) # extend conversation with assistant's reply
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1])
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=self.chunk_index)
self.chunk_index += 1
heartbeat_request = False
function_failed = False

View File

@@ -18,13 +18,13 @@ class MultiAgentMessagingInterface(AgentInterface):
self._captured_messages: List[AssistantMessage] = []
self.metadata = {}
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
"""Ignore internal monologue."""
def assistant_message(self, msg: str, msg_obj: Optional[Message] = None):
"""Ignore normal assistant messages (only capturing send_message calls)."""
def function_message(self, msg: str, msg_obj: Optional[Message] = None):
def function_message(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
"""
Called whenever the agent logs a function call. We'll inspect msg_obj.tool_calls:
- If tool_calls include a function named 'send_message', parse its arguments

View File

@@ -192,6 +192,7 @@ class BackgroundMultiAgent(Agent):
tool_calls=None,
tool_call_id=None,
group_id=self.group_id,
otid=message.otid,
)
for message in messages
]

View File

@@ -99,6 +99,7 @@ class DynamicMultiAgent(Agent):
tool_calls=None,
tool_call_id=None,
group_id=self.group_id,
otid=message.otid,
)
)
@@ -125,6 +126,7 @@ class DynamicMultiAgent(Agent):
role="system",
content=message.content,
name=participant_agent.agent_state.name,
otid=message.otid,
)
for message in assistant_messages
]
@@ -271,4 +273,5 @@ class DynamicMultiAgent(Agent):
tool_calls=None,
tool_call_id=None,
group_id=self.group_id,
otid=Message.generate_otid(),
)

View File

@@ -69,6 +69,7 @@ class RoundRobinMultiAgent(Agent):
tool_calls=None,
tool_call_id=None,
group_id=self.group_id,
otid=message.otid,
)
)
@@ -92,6 +93,7 @@ class RoundRobinMultiAgent(Agent):
role="system",
content=message.content,
name=message.name,
otid=message.otid,
)
for message in assistant_messages
]

View File

@@ -89,6 +89,7 @@ class SupervisorMultiAgent(Agent):
tool_calls=None,
tool_call_id=None,
group_id=self.group_id,
otid=message.otid,
)
for message in messages
]

View File

@@ -38,4 +38,5 @@ def prepare_input_message_create(
model=None, # assigned later?
tool_calls=None, # irrelevant
tool_call_id=None,
otid=message.otid,
)

View File

@@ -30,7 +30,7 @@ class AgentInterface(ABC):
raise NotImplementedError
@abstractmethod
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
"""Letta generates some internal monologue"""
raise NotImplementedError
@@ -40,7 +40,7 @@ class AgentInterface(ABC):
raise NotImplementedError
@abstractmethod
def function_message(self, msg: str, msg_obj: Optional[Message] = None):
def function_message(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
"""Letta calls a function"""
raise NotImplementedError
@@ -79,7 +79,7 @@ class CLIInterface(AgentInterface):
print(fstr.format(msg=msg))
@staticmethod
def internal_monologue(msg: str, msg_obj: Optional[Message] = None):
def internal_monologue(msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
# ANSI escape code for italic is '\x1B[3m'
fstr = f"\x1B[3m{Fore.LIGHTBLACK_EX}{INNER_THOUGHTS_CLI_SYMBOL} {{msg}}{Style.RESET_ALL}"
if STRIP_UI:
@@ -108,7 +108,14 @@ class CLIInterface(AgentInterface):
print(fstr.format(msg=msg))
@staticmethod
def user_message(msg: str, msg_obj: Optional[Message] = None, raw: bool = False, dump: bool = False, debug: bool = DEBUG):
def user_message(
msg: str,
msg_obj: Optional[Message] = None,
raw: bool = False,
dump: bool = False,
debug: bool = DEBUG,
chunk_index: Optional[int] = None,
):
def print_user_message(icon, msg, printf=print):
if STRIP_UI:
printf(f"{icon} {msg}")
@@ -154,7 +161,7 @@ class CLIInterface(AgentInterface):
printd_user_message("🧑", msg_json)
@staticmethod
def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG):
def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG, chunk_index: Optional[int] = None):
def print_function_message(icon, msg, color=Fore.RED, printf=print):
if STRIP_UI:
printf(f"{icon} [function] {msg}")

View File

@@ -953,6 +953,7 @@ def anthropic_chat_completions_process_stream(
# TODO handle emitting redacted reasoning content (e.g. as concat?)
expect_reasoning_content=extended_thinking,
name=name,
chunk_index=chunk_idx,
)
elif isinstance(stream_interface, AgentRefreshStreamingInterface):
stream_interface.process_refresh(chat_completion_response)

View File

@@ -274,6 +274,7 @@ def openai_chat_completions_process_stream(
message_date=chat_completion_response.created if create_message_datetime else chat_completion_chunk.created,
expect_reasoning_content=expect_reasoning_content,
name=name,
chunk_index=chunk_idx,
)
elif isinstance(stream_interface, AgentRefreshStreamingInterface):
stream_interface.process_refresh(chat_completion_response)

View File

@@ -26,11 +26,13 @@ class LettaMessage(BaseModel):
id (str): The ID of the message
date (datetime): The date the message was created in ISO format
name (Optional[str]): The name of the sender of the message
otid (Optional[str]): The offline threading id associated with this message
"""
id: str
date: datetime
name: Optional[str] = None
otid: Optional[str] = None
@field_serializer("date")
def serialize_datetime(self, dt: datetime, _info):

View File

@@ -2,6 +2,7 @@ from __future__ import annotations
import copy
import json
import uuid
import warnings
from collections import OrderedDict
from datetime import datetime, timezone
@@ -78,6 +79,7 @@ class MessageCreate(BaseModel):
json_schema_extra=get_letta_message_content_union_str_json_schema(),
)
name: Optional[str] = Field(None, description="The name of the participant.")
otid: Optional[str] = Field(None, description="The offline threading id associated with this message")
def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
data = super().model_dump(**kwargs)
@@ -168,12 +170,17 @@ class Message(BaseMessage):
json_message["created_at"] = self.created_at.isoformat()
return json_message
@staticmethod
def generate_otid() -> str:
    """Generate a fresh offline threading id (otid) as a random UUID4 string."""
    return str(uuid.uuid4())
@staticmethod
def to_letta_messages_from_list(
messages: List[Message],
use_assistant_message: bool = True,
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
reverse: bool = True,
) -> List[LettaMessage]:
if use_assistant_message:
message_ids_to_remove = []
@@ -203,6 +210,7 @@ class Message(BaseMessage):
use_assistant_message=use_assistant_message,
assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
reverse=reverse,
)
]
@@ -211,6 +219,7 @@ class Message(BaseMessage):
use_assistant_message: bool = False,
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
reverse: bool = True,
) -> List[LettaMessage]:
"""Convert message object (in DB format) to the style used by the original Letta API"""
messages = []
@@ -221,18 +230,21 @@ class Message(BaseMessage):
if self.content:
# Check for ReACT-style COT inside of TextContent
if len(self.content) == 1 and isinstance(self.content[0], TextContent):
otid = Message.generate_otid_from_id(self.id, len(messages))
messages.append(
ReasoningMessage(
id=self.id,
date=self.created_at,
reasoning=self.content[0].text,
name=self.name,
otid=otid,
)
)
# Otherwise, we may have a list of multiple types
else:
# TODO we can probably collapse these two cases into a single loop
for content_part in self.content:
otid = Message.generate_otid_from_id(self.id, len(messages))
if isinstance(content_part, TextContent):
# COT
messages.append(
@@ -241,6 +253,7 @@ class Message(BaseMessage):
date=self.created_at,
reasoning=content_part.text,
name=self.name,
otid=otid,
)
)
elif isinstance(content_part, ReasoningContent):
@@ -253,6 +266,7 @@ class Message(BaseMessage):
source="reasoner_model", # TODO do we want to tag like this?
signature=content_part.signature,
name=self.name,
otid=otid,
)
)
elif isinstance(content_part, RedactedReasoningContent):
@@ -264,6 +278,7 @@ class Message(BaseMessage):
state="redacted",
hidden_reasoning=content_part.data,
name=self.name,
otid=otid,
)
)
else:
@@ -272,6 +287,7 @@ class Message(BaseMessage):
if self.tool_calls is not None:
# This is type FunctionCall
for tool_call in self.tool_calls:
otid = Message.generate_otid_from_id(self.id, len(messages))
# If we're supporting using assistant message,
# then we want to treat certain function calls as a special case
if use_assistant_message and tool_call.function.name == assistant_message_tool_name:
@@ -287,6 +303,7 @@ class Message(BaseMessage):
date=self.created_at,
content=message_string,
name=self.name,
otid=otid,
)
)
else:
@@ -300,6 +317,7 @@ class Message(BaseMessage):
tool_call_id=tool_call.id,
),
name=self.name,
otid=otid,
)
)
elif self.role == MessageRole.tool:
@@ -341,6 +359,7 @@ class Message(BaseMessage):
stdout=self.tool_returns[0].stdout if self.tool_returns else None,
stderr=self.tool_returns[0].stderr if self.tool_returns else None,
name=self.name,
otid=self.id.replace("message-", ""),
)
)
elif self.role == MessageRole.user:
@@ -357,6 +376,7 @@ class Message(BaseMessage):
date=self.created_at,
content=message_str or text_content,
name=self.name,
otid=self.otid,
)
)
elif self.role == MessageRole.system:
@@ -372,11 +392,15 @@ class Message(BaseMessage):
date=self.created_at,
content=text_content,
name=self.name,
otid=self.otid,
)
)
else:
raise ValueError(self.role)
if reverse:
messages.reverse()
return messages
@staticmethod
@@ -991,6 +1015,23 @@ class Message(BaseMessage):
return cohere_message
@staticmethod
def generate_otid_from_id(message_id: str, index: int) -> str:
    """
    Build a deterministic otid from a message id by overwriting the lowest
    7 bits of its UUID with *index* (so up to 128 distinct otids per message).
    """
    if index < 0 or index > 127:
        raise ValueError("Index must be between 0 and 127")
    # Strip the "message-" prefix and the dashes to get the raw 32 hex digits.
    raw_hex = message_id.replace("message-", "").replace("-", "")
    # Clear the low 7 bits, then stamp the index into them.
    stamped = (int(raw_hex, 16) & ~0x7F) | (index & 0x7F)
    digits = f"{stamped:032x}"
    # Re-insert dashes in the canonical 8-4-4-4-12 UUID grouping.
    parts = (digits[:8], digits[8:12], digits[12:16], digits[16:20], digits[20:])
    return "-".join(parts)
class ToolReturn(BaseModel):
status: Literal["success", "error"] = Field(..., description="The status of the tool call")

View File

@@ -172,7 +172,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
"""
return
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None) -> None:
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None) -> None:
"""
Handle LLM reasoning or internal monologue. Example usage: if you want
to capture chain-of-thought for debugging in a non-streaming scenario.
@@ -186,7 +186,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
"""
return
def function_message(self, msg: str, msg_obj: Optional[Message] = None) -> None:
def function_message(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None) -> None:
"""
Handle function-related log messages, typically of the form:
It's a no-op by default.

View File

@@ -165,7 +165,7 @@ class QueuingInterface(AgentInterface):
print(vars(msg_obj))
print(msg_obj.created_at.isoformat())
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None) -> None:
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None) -> None:
"""Handle the agent's internal monologue"""
assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
if self.debug:
@@ -209,7 +209,9 @@ class QueuingInterface(AgentInterface):
self._queue_push(message_api=new_message, message_obj=msg_obj)
def function_message(self, msg: str, msg_obj: Optional[Message] = None, include_ran_messages: bool = False) -> None:
def function_message(
self, msg: str, msg_obj: Optional[Message] = None, include_ran_messages: bool = False, chunk_index: Optional[int] = None
) -> None:
"""Handle the agent calling a function"""
# TODO handle 'function' messages that indicate the start of a function call
assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
@@ -466,6 +468,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# and `content` needs to be handled outside the interface
expect_reasoning_content: bool = False,
name: Optional[str] = None,
chunk_index: int = 0,
) -> Optional[Union[ReasoningMessage, ToolCallMessage, AssistantMessage]]:
"""
Example data from non-streaming response looks like:
@@ -478,6 +481,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
"""
choice = chunk.choices[0]
message_delta = choice.delta
otid = Message.generate_otid_from_id(message_id, chunk_index)
if (
message_delta.content is None
@@ -499,6 +503,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
signature=message_delta.reasoning_content_signature,
source="reasoner_model" if message_delta.reasoning_content_signature else "non_reasoner_model",
name=name,
otid=otid,
)
elif expect_reasoning_content and message_delta.redacted_reasoning_content is not None:
processed_chunk = HiddenReasoningMessage(
@@ -507,6 +512,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
hidden_reasoning=message_delta.redacted_reasoning_content,
state="redacted",
name=name,
otid=otid,
)
elif expect_reasoning_content and message_delta.content is not None:
# "ignore" content if we expect reasoning content
@@ -534,6 +540,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
tool_call_id=None,
),
name=name,
otid=otid,
)
except json.JSONDecodeError as e:
@@ -564,6 +571,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
date=message_date,
reasoning=message_delta.content,
name=name,
otid=otid,
)
# tool calls
@@ -612,7 +620,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# TODO: Assumes consistent state and that prev_content is subset of new_content
diff = new_content.replace(prev_content, "", 1)
self.current_json_parse_result = parsed_args
processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff, name=name)
processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff, name=name, otid=otid)
else:
return None
@@ -645,6 +653,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
tool_call_id=tool_call_delta.get("id"),
),
name=name,
otid=otid,
)
elif self.inner_thoughts_in_kwargs and tool_call.function:
@@ -681,6 +690,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
date=message_date,
reasoning=updates_inner_thoughts,
name=name,
otid=otid,
)
# Additionally inner thoughts may stream back with a chunk of main JSON
# In that case, since we can only return a chunk at a time, we should buffer it
@@ -717,6 +727,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
tool_call_id=self.function_id_buffer,
),
name=name,
otid=otid,
)
# Record what the last function name we flushed was
@@ -774,6 +785,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
date=message_date,
content=combined_chunk,
name=name,
otid=otid,
)
# Store the ID of the tool call so allow skipping the corresponding response
if self.function_id_buffer:
@@ -798,7 +810,9 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# TODO: Assumes consistent state and that prev_content is subset of new_content
diff = new_content.replace(prev_content, "", 1)
self.current_json_parse_result = parsed_args
processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff, name=name)
processed_chunk = AssistantMessage(
id=message_id, date=message_date, content=diff, name=name, otid=otid
)
else:
return None
@@ -823,6 +837,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
tool_call_id=self.function_id_buffer,
),
name=name,
otid=otid,
)
# clear buffer
self.function_args_buffer = None
@@ -838,6 +853,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
tool_call_id=self.function_id_buffer,
),
name=name,
otid=otid,
)
self.function_id_buffer = None
@@ -967,6 +983,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
tool_call_id=tool_call_delta.get("id"),
),
name=name,
otid=otid,
)
elif choice.finish_reason is not None:
@@ -1048,6 +1065,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
message_date: datetime,
expect_reasoning_content: bool = False,
name: Optional[str] = None,
chunk_index: int = 0,
):
"""Process a streaming chunk from an OpenAI-compatible server.
@@ -1074,6 +1092,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
message_date=message_date,
expect_reasoning_content=expect_reasoning_content,
name=name,
chunk_index=chunk_index,
)
if processed_chunk is None:
@@ -1085,7 +1104,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
"""Letta receives a user message"""
return
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
"""Letta generates some internal monologue"""
if not self.streaming_mode:
@@ -1102,6 +1121,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
date=msg_obj.created_at,
reasoning=msg,
name=msg_obj.name,
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None,
)
self._push_to_buffer(processed_chunk)
@@ -1113,6 +1133,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
date=msg_obj.created_at,
reasoning=content.text,
name=msg_obj.name,
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None,
)
elif isinstance(content, ReasoningContent):
processed_chunk = ReasoningMessage(
@@ -1122,6 +1143,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
reasoning=content.reasoning,
signature=content.signature,
name=msg_obj.name,
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None,
)
elif isinstance(content, RedactedReasoningContent):
processed_chunk = HiddenReasoningMessage(
@@ -1130,6 +1152,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
state="redacted",
hidden_reasoning=content.data,
name=msg_obj.name,
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None,
)
self._push_to_buffer(processed_chunk)
@@ -1142,7 +1165,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# NOTE: this is a no-op, we handle this special case in function_message instead
return
def function_message(self, msg: str, msg_obj: Optional[Message] = None):
def function_message(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
"""Letta calls a function"""
# TODO handle 'function' messages that indicate the start of a function call
@@ -1191,6 +1214,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
date=msg_obj.created_at,
content=func_args["message"],
name=msg_obj.name,
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None,
)
self._push_to_buffer(processed_chunk)
except Exception as e:
@@ -1214,6 +1238,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
date=msg_obj.created_at,
content=func_args[self.assistant_message_tool_kwarg],
name=msg_obj.name,
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None,
)
# Store the ID of the tool call so allow skipping the corresponding response
self.prev_assistant_message_id = function_call.id
@@ -1227,6 +1252,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
tool_call_id=function_call.id,
),
name=msg_obj.name,
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None,
)
# processed_chunk = {
@@ -1267,6 +1293,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
stdout=msg_obj.tool_returns[0].stdout if msg_obj.tool_returns else None,
stderr=msg_obj.tool_returns[0].stderr if msg_obj.tool_returns else None,
name=msg_obj.name,
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index),
)
elif msg.startswith("Error: "):
@@ -1282,6 +1309,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
stdout=msg_obj.tool_returns[0].stdout if msg_obj.tool_returns else None,
stderr=msg_obj.tool_returns[0].stderr if msg_obj.tool_returns else None,
name=msg_obj.name,
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index),
)
else:

View File

@@ -892,6 +892,7 @@ class SyncServer(Server):
use_assistant_message=use_assistant_message,
assistant_message_tool_name=assistant_message_tool_name,
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
reverse=reverse,
)
if reverse:

View File

@@ -33,7 +33,7 @@ class AgentChunkStreamingInterface(ABC):
raise NotImplementedError
@abstractmethod
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
"""Letta generates some internal monologue"""
raise NotImplementedError
@@ -43,13 +43,18 @@ class AgentChunkStreamingInterface(ABC):
raise NotImplementedError
@abstractmethod
def function_message(self, msg: str, msg_obj: Optional[Message] = None):
def function_message(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
"""Letta calls a function"""
raise NotImplementedError
@abstractmethod
def process_chunk(
self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime, expect_reasoning_content: bool = False
self,
chunk: ChatCompletionChunkResponse,
message_id: str,
message_date: datetime,
expect_reasoning_content: bool = False,
chunk_index: int = 0,
):
"""Process a streaming chunk from an OpenAI-compatible server"""
raise NotImplementedError
@@ -166,7 +171,7 @@ class StreamingCLIInterface(AgentChunkStreamingInterface):
StreamingCLIInterface.nonstreaming_interface(msg)
@staticmethod
def internal_monologue(msg: str, msg_obj: Optional[Message] = None):
def internal_monologue(msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
StreamingCLIInterface.nonstreaming_interface(msg, msg_obj)
@staticmethod
@@ -186,7 +191,7 @@ class StreamingCLIInterface(AgentChunkStreamingInterface):
StreamingCLIInterface.nonstreaming_interface(msg, msg_obj)
@staticmethod
def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG):
def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG, chunk_index: Optional[int] = None):
StreamingCLIInterface.nonstreaming_interface(msg, msg_obj)
@staticmethod
@@ -218,7 +223,7 @@ class AgentRefreshStreamingInterface(ABC):
raise NotImplementedError
@abstractmethod
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
"""Letta generates some internal monologue"""
raise NotImplementedError
@@ -228,7 +233,7 @@ class AgentRefreshStreamingInterface(ABC):
raise NotImplementedError
@abstractmethod
def function_message(self, msg: str, msg_obj: Optional[Message] = None):
def function_message(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
"""Letta calls a function"""
raise NotImplementedError
@@ -355,7 +360,7 @@ class StreamingRefreshCLIInterface(AgentRefreshStreamingInterface):
def warning_message(msg: str):
StreamingCLIInterface.nonstreaming_interface.warning_message(msg)
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
if self.disable_inner_mono_call:
return
StreamingCLIInterface.nonstreaming_interface.internal_monologue(msg, msg_obj)
@@ -378,7 +383,7 @@ class StreamingRefreshCLIInterface(AgentRefreshStreamingInterface):
StreamingCLIInterface.nonstreaming_interface.user_message(msg, msg_obj)
@staticmethod
def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG):
def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG, chunk_index: Optional[int] = None):
StreamingCLIInterface.nonstreaming_interface.function_message(msg, msg_obj)
@staticmethod

View File

@@ -6,6 +6,8 @@ import pytest
from dotenv import load_dotenv
from letta_client import AgentState, Letta, LlmConfig, MessageCreate
from letta.schemas.message import Message
def run_server():
load_dotenv()
@@ -73,9 +75,16 @@ def test_streaming_send_message(
client.agents.modify(agent_id=agent.id, llm_config=config)
# Send streaming message
user_message_otid = Message.generate_otid()
response = client.agents.messages.create_stream(
agent_id=agent.id,
messages=[MessageCreate(role="user", content="This is a test. Repeat after me: 'banana'")],
messages=[
MessageCreate(
role="user",
content="This is a test. Repeat after me: 'banana'",
otid=user_message_otid,
),
],
stream_tokens=stream_tokens,
)
@@ -84,6 +93,8 @@ def test_streaming_send_message(
inner_thoughts_count = 0
send_message_ran = False
done = False
last_message_id = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id
letta_message_otids = [user_message_otid]
assert response, "Sending message failed"
for chunk in response:
@@ -104,6 +115,8 @@ def test_streaming_send_message(
assert chunk.prompt_tokens > 1000
assert chunk.total_tokens > 1000
done = True
else:
letta_message_otids.append(chunk.otid)
print(chunk)
# If stream tokens, we expect at least one inner thought
@@ -111,3 +124,6 @@ def test_streaming_send_message(
assert inner_thoughts_exist, "No inner thoughts found"
assert send_message_ran, "send_message function call not found"
assert done, "Message stream not done"
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id)
assert [message.otid for message in messages] == letta_message_otids