feat: add otid field for message idempotency (#1556)
This commit is contained in:
@@ -130,6 +130,7 @@ class Agent(BaseAgent):
|
||||
# Different interfaces can handle events differently
|
||||
# e.g., print in CLI vs send a discord message with a discord bot
|
||||
self.interface = interface
|
||||
self.chunk_index = 0
|
||||
|
||||
# Create the persistence manager object based on the AgentState info
|
||||
self.message_manager = MessageManager()
|
||||
@@ -246,9 +247,11 @@ class Agent(BaseAgent):
|
||||
group_id=group_id,
|
||||
)
|
||||
messages.append(new_message)
|
||||
self.interface.function_message(f"Error: {error_msg}", msg_obj=new_message)
|
||||
self.interface.function_message(f"Error: {error_msg}", msg_obj=new_message, chunk_index=self.chunk_index)
|
||||
self.chunk_index += 1
|
||||
if include_function_failed_message:
|
||||
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=new_message)
|
||||
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=new_message, chunk_index=self.chunk_index)
|
||||
self.chunk_index += 1
|
||||
|
||||
# Return updated messages
|
||||
return messages
|
||||
@@ -430,7 +433,8 @@ class Agent(BaseAgent):
|
||||
nonnull_content = False
|
||||
if response_message.content or response_message.reasoning_content or response_message.redacted_reasoning_content:
|
||||
# The content if then internal monologue, not chat
|
||||
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1])
|
||||
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=self.chunk_index)
|
||||
self.chunk_index += 1
|
||||
# Flag to avoid printing a duplicate if inner thoughts get popped from the function call
|
||||
nonnull_content = True
|
||||
|
||||
@@ -479,7 +483,8 @@ class Agent(BaseAgent):
|
||||
response_message.content = function_args.pop("inner_thoughts")
|
||||
# The content if then internal monologue, not chat
|
||||
if response_message.content and not nonnull_content:
|
||||
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1])
|
||||
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=self.chunk_index)
|
||||
self.chunk_index += 1
|
||||
|
||||
# (Still parsing function args)
|
||||
# Handle requests for immediate heartbeat
|
||||
@@ -501,7 +506,8 @@ class Agent(BaseAgent):
|
||||
# Failure case 3: function failed during execution
|
||||
# NOTE: the msg_obj associated with the "Running " message is the prior assistant message, not the function/tool role message
|
||||
# this is because the function/tool role message is only created once the function/tool has executed/returned
|
||||
self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1])
|
||||
self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1], chunk_index=self.chunk_index)
|
||||
self.chunk_index += 1
|
||||
try:
|
||||
# handle tool execution (sandbox) and state updates
|
||||
log_telemetry(
|
||||
@@ -634,8 +640,10 @@ class Agent(BaseAgent):
|
||||
group_id=group_id,
|
||||
)
|
||||
) # extend conversation with function response
|
||||
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1])
|
||||
self.interface.function_message(f"Success: {function_response_string}", msg_obj=messages[-1])
|
||||
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1], chunk_index=self.chunk_index)
|
||||
self.chunk_index += 1
|
||||
self.interface.function_message(f"Success: {function_response_string}", msg_obj=messages[-1], chunk_index=self.chunk_index)
|
||||
self.chunk_index += 1
|
||||
self.last_function_response = function_response
|
||||
|
||||
else:
|
||||
@@ -651,7 +659,8 @@ class Agent(BaseAgent):
|
||||
group_id=group_id,
|
||||
)
|
||||
) # extend conversation with assistant's reply
|
||||
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1])
|
||||
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1], chunk_index=self.chunk_index)
|
||||
self.chunk_index += 1
|
||||
heartbeat_request = False
|
||||
function_failed = False
|
||||
|
||||
|
||||
@@ -18,13 +18,13 @@ class MultiAgentMessagingInterface(AgentInterface):
|
||||
self._captured_messages: List[AssistantMessage] = []
|
||||
self.metadata = {}
|
||||
|
||||
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
|
||||
"""Ignore internal monologue."""
|
||||
|
||||
def assistant_message(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
"""Ignore normal assistant messages (only capturing send_message calls)."""
|
||||
|
||||
def function_message(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
def function_message(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
|
||||
"""
|
||||
Called whenever the agent logs a function call. We'll inspect msg_obj.tool_calls:
|
||||
- If tool_calls include a function named 'send_message', parse its arguments
|
||||
|
||||
@@ -192,6 +192,7 @@ class BackgroundMultiAgent(Agent):
|
||||
tool_calls=None,
|
||||
tool_call_id=None,
|
||||
group_id=self.group_id,
|
||||
otid=message.otid,
|
||||
)
|
||||
for message in messages
|
||||
]
|
||||
|
||||
@@ -99,6 +99,7 @@ class DynamicMultiAgent(Agent):
|
||||
tool_calls=None,
|
||||
tool_call_id=None,
|
||||
group_id=self.group_id,
|
||||
otid=message.otid,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -125,6 +126,7 @@ class DynamicMultiAgent(Agent):
|
||||
role="system",
|
||||
content=message.content,
|
||||
name=participant_agent.agent_state.name,
|
||||
otid=message.otid,
|
||||
)
|
||||
for message in assistant_messages
|
||||
]
|
||||
@@ -271,4 +273,5 @@ class DynamicMultiAgent(Agent):
|
||||
tool_calls=None,
|
||||
tool_call_id=None,
|
||||
group_id=self.group_id,
|
||||
otid=Message.generate_otid(),
|
||||
)
|
||||
|
||||
@@ -69,6 +69,7 @@ class RoundRobinMultiAgent(Agent):
|
||||
tool_calls=None,
|
||||
tool_call_id=None,
|
||||
group_id=self.group_id,
|
||||
otid=message.otid,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -92,6 +93,7 @@ class RoundRobinMultiAgent(Agent):
|
||||
role="system",
|
||||
content=message.content,
|
||||
name=message.name,
|
||||
otid=message.otid,
|
||||
)
|
||||
for message in assistant_messages
|
||||
]
|
||||
|
||||
@@ -89,6 +89,7 @@ class SupervisorMultiAgent(Agent):
|
||||
tool_calls=None,
|
||||
tool_call_id=None,
|
||||
group_id=self.group_id,
|
||||
otid=message.otid,
|
||||
)
|
||||
for message in messages
|
||||
]
|
||||
|
||||
@@ -38,4 +38,5 @@ def prepare_input_message_create(
|
||||
model=None, # assigned later?
|
||||
tool_calls=None, # irrelevant
|
||||
tool_call_id=None,
|
||||
otid=message.otid,
|
||||
)
|
||||
|
||||
@@ -30,7 +30,7 @@ class AgentInterface(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
|
||||
"""Letta generates some internal monologue"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -40,7 +40,7 @@ class AgentInterface(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def function_message(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
def function_message(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
|
||||
"""Letta calls a function"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -79,7 +79,7 @@ class CLIInterface(AgentInterface):
|
||||
print(fstr.format(msg=msg))
|
||||
|
||||
@staticmethod
|
||||
def internal_monologue(msg: str, msg_obj: Optional[Message] = None):
|
||||
def internal_monologue(msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
|
||||
# ANSI escape code for italic is '\x1B[3m'
|
||||
fstr = f"\x1B[3m{Fore.LIGHTBLACK_EX}{INNER_THOUGHTS_CLI_SYMBOL} {{msg}}{Style.RESET_ALL}"
|
||||
if STRIP_UI:
|
||||
@@ -108,7 +108,14 @@ class CLIInterface(AgentInterface):
|
||||
print(fstr.format(msg=msg))
|
||||
|
||||
@staticmethod
|
||||
def user_message(msg: str, msg_obj: Optional[Message] = None, raw: bool = False, dump: bool = False, debug: bool = DEBUG):
|
||||
def user_message(
|
||||
msg: str,
|
||||
msg_obj: Optional[Message] = None,
|
||||
raw: bool = False,
|
||||
dump: bool = False,
|
||||
debug: bool = DEBUG,
|
||||
chunk_index: Optional[int] = None,
|
||||
):
|
||||
def print_user_message(icon, msg, printf=print):
|
||||
if STRIP_UI:
|
||||
printf(f"{icon} {msg}")
|
||||
@@ -154,7 +161,7 @@ class CLIInterface(AgentInterface):
|
||||
printd_user_message("🧑", msg_json)
|
||||
|
||||
@staticmethod
|
||||
def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG):
|
||||
def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG, chunk_index: Optional[int] = None):
|
||||
def print_function_message(icon, msg, color=Fore.RED, printf=print):
|
||||
if STRIP_UI:
|
||||
printf(f"⚡{icon} [function] {msg}")
|
||||
|
||||
@@ -953,6 +953,7 @@ def anthropic_chat_completions_process_stream(
|
||||
# TODO handle emitting redacted reasoning content (e.g. as concat?)
|
||||
expect_reasoning_content=extended_thinking,
|
||||
name=name,
|
||||
chunk_index=chunk_idx,
|
||||
)
|
||||
elif isinstance(stream_interface, AgentRefreshStreamingInterface):
|
||||
stream_interface.process_refresh(chat_completion_response)
|
||||
|
||||
@@ -274,6 +274,7 @@ def openai_chat_completions_process_stream(
|
||||
message_date=chat_completion_response.created if create_message_datetime else chat_completion_chunk.created,
|
||||
expect_reasoning_content=expect_reasoning_content,
|
||||
name=name,
|
||||
chunk_index=chunk_idx,
|
||||
)
|
||||
elif isinstance(stream_interface, AgentRefreshStreamingInterface):
|
||||
stream_interface.process_refresh(chat_completion_response)
|
||||
|
||||
@@ -26,11 +26,13 @@ class LettaMessage(BaseModel):
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
name (Optional[str]): The name of the sender of the message
|
||||
otid (Optional[str]): The offline threading id associated with this message
|
||||
"""
|
||||
|
||||
id: str
|
||||
date: datetime
|
||||
name: Optional[str] = None
|
||||
otid: Optional[str] = None
|
||||
|
||||
@field_serializer("date")
|
||||
def serialize_datetime(self, dt: datetime, _info):
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import json
|
||||
import uuid
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timezone
|
||||
@@ -78,6 +79,7 @@ class MessageCreate(BaseModel):
|
||||
json_schema_extra=get_letta_message_content_union_str_json_schema(),
|
||||
)
|
||||
name: Optional[str] = Field(None, description="The name of the participant.")
|
||||
otid: Optional[str] = Field(None, description="The offline threading id associated with this message")
|
||||
|
||||
def model_dump(self, to_orm: bool = False, **kwargs) -> Dict[str, Any]:
|
||||
data = super().model_dump(**kwargs)
|
||||
@@ -168,12 +170,17 @@ class Message(BaseMessage):
|
||||
json_message["created_at"] = self.created_at.isoformat()
|
||||
return json_message
|
||||
|
||||
@staticmethod
|
||||
def generate_otid():
|
||||
return str(uuid.uuid4())
|
||||
|
||||
@staticmethod
|
||||
def to_letta_messages_from_list(
|
||||
messages: List[Message],
|
||||
use_assistant_message: bool = True,
|
||||
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
|
||||
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
reverse: bool = True,
|
||||
) -> List[LettaMessage]:
|
||||
if use_assistant_message:
|
||||
message_ids_to_remove = []
|
||||
@@ -203,6 +210,7 @@ class Message(BaseMessage):
|
||||
use_assistant_message=use_assistant_message,
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
reverse=reverse,
|
||||
)
|
||||
]
|
||||
|
||||
@@ -211,6 +219,7 @@ class Message(BaseMessage):
|
||||
use_assistant_message: bool = False,
|
||||
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
|
||||
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
reverse: bool = True,
|
||||
) -> List[LettaMessage]:
|
||||
"""Convert message object (in DB format) to the style used by the original Letta API"""
|
||||
messages = []
|
||||
@@ -221,18 +230,21 @@ class Message(BaseMessage):
|
||||
if self.content:
|
||||
# Check for ReACT-style COT inside of TextContent
|
||||
if len(self.content) == 1 and isinstance(self.content[0], TextContent):
|
||||
otid = Message.generate_otid_from_id(self.id, len(messages))
|
||||
messages.append(
|
||||
ReasoningMessage(
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
reasoning=self.content[0].text,
|
||||
name=self.name,
|
||||
otid=otid,
|
||||
)
|
||||
)
|
||||
# Otherwise, we may have a list of multiple types
|
||||
else:
|
||||
# TODO we can probably collapse these two cases into a single loop
|
||||
for content_part in self.content:
|
||||
otid = Message.generate_otid_from_id(self.id, len(messages))
|
||||
if isinstance(content_part, TextContent):
|
||||
# COT
|
||||
messages.append(
|
||||
@@ -241,6 +253,7 @@ class Message(BaseMessage):
|
||||
date=self.created_at,
|
||||
reasoning=content_part.text,
|
||||
name=self.name,
|
||||
otid=otid,
|
||||
)
|
||||
)
|
||||
elif isinstance(content_part, ReasoningContent):
|
||||
@@ -253,6 +266,7 @@ class Message(BaseMessage):
|
||||
source="reasoner_model", # TODO do we want to tag like this?
|
||||
signature=content_part.signature,
|
||||
name=self.name,
|
||||
otid=otid,
|
||||
)
|
||||
)
|
||||
elif isinstance(content_part, RedactedReasoningContent):
|
||||
@@ -264,6 +278,7 @@ class Message(BaseMessage):
|
||||
state="redacted",
|
||||
hidden_reasoning=content_part.data,
|
||||
name=self.name,
|
||||
otid=otid,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -272,6 +287,7 @@ class Message(BaseMessage):
|
||||
if self.tool_calls is not None:
|
||||
# This is type FunctionCall
|
||||
for tool_call in self.tool_calls:
|
||||
otid = Message.generate_otid_from_id(self.id, len(messages))
|
||||
# If we're supporting using assistant message,
|
||||
# then we want to treat certain function calls as a special case
|
||||
if use_assistant_message and tool_call.function.name == assistant_message_tool_name:
|
||||
@@ -287,6 +303,7 @@ class Message(BaseMessage):
|
||||
date=self.created_at,
|
||||
content=message_string,
|
||||
name=self.name,
|
||||
otid=otid,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -300,6 +317,7 @@ class Message(BaseMessage):
|
||||
tool_call_id=tool_call.id,
|
||||
),
|
||||
name=self.name,
|
||||
otid=otid,
|
||||
)
|
||||
)
|
||||
elif self.role == MessageRole.tool:
|
||||
@@ -341,6 +359,7 @@ class Message(BaseMessage):
|
||||
stdout=self.tool_returns[0].stdout if self.tool_returns else None,
|
||||
stderr=self.tool_returns[0].stderr if self.tool_returns else None,
|
||||
name=self.name,
|
||||
otid=self.id.replace("message-", ""),
|
||||
)
|
||||
)
|
||||
elif self.role == MessageRole.user:
|
||||
@@ -357,6 +376,7 @@ class Message(BaseMessage):
|
||||
date=self.created_at,
|
||||
content=message_str or text_content,
|
||||
name=self.name,
|
||||
otid=self.otid,
|
||||
)
|
||||
)
|
||||
elif self.role == MessageRole.system:
|
||||
@@ -372,11 +392,15 @@ class Message(BaseMessage):
|
||||
date=self.created_at,
|
||||
content=text_content,
|
||||
name=self.name,
|
||||
otid=self.otid,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(self.role)
|
||||
|
||||
if reverse:
|
||||
messages.reverse()
|
||||
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
@@ -991,6 +1015,23 @@ class Message(BaseMessage):
|
||||
|
||||
return cohere_message
|
||||
|
||||
@staticmethod
|
||||
def generate_otid_from_id(message_id: str, index: int) -> str:
|
||||
"""
|
||||
Convert message id to bits and change the list bit to the index
|
||||
"""
|
||||
if not 0 <= index < 128:
|
||||
raise ValueError("Index must be between 0 and 127")
|
||||
|
||||
message_uuid = message_id.replace("message-", "")
|
||||
uuid_int = int(message_uuid.replace("-", ""), 16)
|
||||
|
||||
# Clear last 7 bits and set them to index; supports up to 128 unique indices
|
||||
uuid_int = (uuid_int & ~0x7F) | (index & 0x7F)
|
||||
|
||||
hex_str = f"{uuid_int:032x}"
|
||||
return f"{hex_str[:8]}-{hex_str[8:12]}-{hex_str[12:16]}-{hex_str[16:20]}-{hex_str[20:]}"
|
||||
|
||||
|
||||
class ToolReturn(BaseModel):
|
||||
status: Literal["success", "error"] = Field(..., description="The status of the tool call")
|
||||
|
||||
@@ -172,7 +172,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
|
||||
"""
|
||||
return
|
||||
|
||||
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None) -> None:
|
||||
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None) -> None:
|
||||
"""
|
||||
Handle LLM reasoning or internal monologue. Example usage: if you want
|
||||
to capture chain-of-thought for debugging in a non-streaming scenario.
|
||||
@@ -186,7 +186,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
|
||||
"""
|
||||
return
|
||||
|
||||
def function_message(self, msg: str, msg_obj: Optional[Message] = None) -> None:
|
||||
def function_message(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None) -> None:
|
||||
"""
|
||||
Handle function-related log messages, typically of the form:
|
||||
It's a no-op by default.
|
||||
|
||||
@@ -165,7 +165,7 @@ class QueuingInterface(AgentInterface):
|
||||
print(vars(msg_obj))
|
||||
print(msg_obj.created_at.isoformat())
|
||||
|
||||
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None) -> None:
|
||||
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None) -> None:
|
||||
"""Handle the agent's internal monologue"""
|
||||
assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
|
||||
if self.debug:
|
||||
@@ -209,7 +209,9 @@ class QueuingInterface(AgentInterface):
|
||||
|
||||
self._queue_push(message_api=new_message, message_obj=msg_obj)
|
||||
|
||||
def function_message(self, msg: str, msg_obj: Optional[Message] = None, include_ran_messages: bool = False) -> None:
|
||||
def function_message(
|
||||
self, msg: str, msg_obj: Optional[Message] = None, include_ran_messages: bool = False, chunk_index: Optional[int] = None
|
||||
) -> None:
|
||||
"""Handle the agent calling a function"""
|
||||
# TODO handle 'function' messages that indicate the start of a function call
|
||||
assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
|
||||
@@ -466,6 +468,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# and `content` needs to be handled outside the interface
|
||||
expect_reasoning_content: bool = False,
|
||||
name: Optional[str] = None,
|
||||
chunk_index: int = 0,
|
||||
) -> Optional[Union[ReasoningMessage, ToolCallMessage, AssistantMessage]]:
|
||||
"""
|
||||
Example data from non-streaming response looks like:
|
||||
@@ -478,6 +481,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
"""
|
||||
choice = chunk.choices[0]
|
||||
message_delta = choice.delta
|
||||
otid = Message.generate_otid_from_id(message_id, chunk_index)
|
||||
|
||||
if (
|
||||
message_delta.content is None
|
||||
@@ -499,6 +503,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
signature=message_delta.reasoning_content_signature,
|
||||
source="reasoner_model" if message_delta.reasoning_content_signature else "non_reasoner_model",
|
||||
name=name,
|
||||
otid=otid,
|
||||
)
|
||||
elif expect_reasoning_content and message_delta.redacted_reasoning_content is not None:
|
||||
processed_chunk = HiddenReasoningMessage(
|
||||
@@ -507,6 +512,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
hidden_reasoning=message_delta.redacted_reasoning_content,
|
||||
state="redacted",
|
||||
name=name,
|
||||
otid=otid,
|
||||
)
|
||||
elif expect_reasoning_content and message_delta.content is not None:
|
||||
# "ignore" content if we expect reasoning content
|
||||
@@ -534,6 +540,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
tool_call_id=None,
|
||||
),
|
||||
name=name,
|
||||
otid=otid,
|
||||
)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
@@ -564,6 +571,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
date=message_date,
|
||||
reasoning=message_delta.content,
|
||||
name=name,
|
||||
otid=otid,
|
||||
)
|
||||
|
||||
# tool calls
|
||||
@@ -612,7 +620,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# TODO: Assumes consistent state and that prev_content is subset of new_content
|
||||
diff = new_content.replace(prev_content, "", 1)
|
||||
self.current_json_parse_result = parsed_args
|
||||
processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff, name=name)
|
||||
processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff, name=name, otid=otid)
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -645,6 +653,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
tool_call_id=tool_call_delta.get("id"),
|
||||
),
|
||||
name=name,
|
||||
otid=otid,
|
||||
)
|
||||
|
||||
elif self.inner_thoughts_in_kwargs and tool_call.function:
|
||||
@@ -681,6 +690,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
date=message_date,
|
||||
reasoning=updates_inner_thoughts,
|
||||
name=name,
|
||||
otid=otid,
|
||||
)
|
||||
# Additionally inner thoughts may stream back with a chunk of main JSON
|
||||
# In that case, since we can only return a chunk at a time, we should buffer it
|
||||
@@ -717,6 +727,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
tool_call_id=self.function_id_buffer,
|
||||
),
|
||||
name=name,
|
||||
otid=otid,
|
||||
)
|
||||
|
||||
# Record what the last function name we flushed was
|
||||
@@ -774,6 +785,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
date=message_date,
|
||||
content=combined_chunk,
|
||||
name=name,
|
||||
otid=otid,
|
||||
)
|
||||
# Store the ID of the tool call so allow skipping the corresponding response
|
||||
if self.function_id_buffer:
|
||||
@@ -798,7 +810,9 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# TODO: Assumes consistent state and that prev_content is subset of new_content
|
||||
diff = new_content.replace(prev_content, "", 1)
|
||||
self.current_json_parse_result = parsed_args
|
||||
processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff, name=name)
|
||||
processed_chunk = AssistantMessage(
|
||||
id=message_id, date=message_date, content=diff, name=name, otid=otid
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -823,6 +837,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
tool_call_id=self.function_id_buffer,
|
||||
),
|
||||
name=name,
|
||||
otid=otid,
|
||||
)
|
||||
# clear buffer
|
||||
self.function_args_buffer = None
|
||||
@@ -838,6 +853,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
tool_call_id=self.function_id_buffer,
|
||||
),
|
||||
name=name,
|
||||
otid=otid,
|
||||
)
|
||||
self.function_id_buffer = None
|
||||
|
||||
@@ -967,6 +983,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
tool_call_id=tool_call_delta.get("id"),
|
||||
),
|
||||
name=name,
|
||||
otid=otid,
|
||||
)
|
||||
|
||||
elif choice.finish_reason is not None:
|
||||
@@ -1048,6 +1065,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
message_date: datetime,
|
||||
expect_reasoning_content: bool = False,
|
||||
name: Optional[str] = None,
|
||||
chunk_index: int = 0,
|
||||
):
|
||||
"""Process a streaming chunk from an OpenAI-compatible server.
|
||||
|
||||
@@ -1074,6 +1092,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
message_date=message_date,
|
||||
expect_reasoning_content=expect_reasoning_content,
|
||||
name=name,
|
||||
chunk_index=chunk_index,
|
||||
)
|
||||
|
||||
if processed_chunk is None:
|
||||
@@ -1085,7 +1104,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
"""Letta receives a user message"""
|
||||
return
|
||||
|
||||
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
|
||||
"""Letta generates some internal monologue"""
|
||||
if not self.streaming_mode:
|
||||
|
||||
@@ -1102,6 +1121,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
date=msg_obj.created_at,
|
||||
reasoning=msg,
|
||||
name=msg_obj.name,
|
||||
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None,
|
||||
)
|
||||
|
||||
self._push_to_buffer(processed_chunk)
|
||||
@@ -1113,6 +1133,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
date=msg_obj.created_at,
|
||||
reasoning=content.text,
|
||||
name=msg_obj.name,
|
||||
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None,
|
||||
)
|
||||
elif isinstance(content, ReasoningContent):
|
||||
processed_chunk = ReasoningMessage(
|
||||
@@ -1122,6 +1143,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
reasoning=content.reasoning,
|
||||
signature=content.signature,
|
||||
name=msg_obj.name,
|
||||
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None,
|
||||
)
|
||||
elif isinstance(content, RedactedReasoningContent):
|
||||
processed_chunk = HiddenReasoningMessage(
|
||||
@@ -1130,6 +1152,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
state="redacted",
|
||||
hidden_reasoning=content.data,
|
||||
name=msg_obj.name,
|
||||
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None,
|
||||
)
|
||||
|
||||
self._push_to_buffer(processed_chunk)
|
||||
@@ -1142,7 +1165,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# NOTE: this is a no-op, we handle this special case in function_message instead
|
||||
return
|
||||
|
||||
def function_message(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
def function_message(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
|
||||
"""Letta calls a function"""
|
||||
|
||||
# TODO handle 'function' messages that indicate the start of a function call
|
||||
@@ -1191,6 +1214,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
date=msg_obj.created_at,
|
||||
content=func_args["message"],
|
||||
name=msg_obj.name,
|
||||
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None,
|
||||
)
|
||||
self._push_to_buffer(processed_chunk)
|
||||
except Exception as e:
|
||||
@@ -1214,6 +1238,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
date=msg_obj.created_at,
|
||||
content=func_args[self.assistant_message_tool_kwarg],
|
||||
name=msg_obj.name,
|
||||
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None,
|
||||
)
|
||||
# Store the ID of the tool call so allow skipping the corresponding response
|
||||
self.prev_assistant_message_id = function_call.id
|
||||
@@ -1227,6 +1252,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
tool_call_id=function_call.id,
|
||||
),
|
||||
name=msg_obj.name,
|
||||
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None,
|
||||
)
|
||||
|
||||
# processed_chunk = {
|
||||
@@ -1267,6 +1293,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
stdout=msg_obj.tool_returns[0].stdout if msg_obj.tool_returns else None,
|
||||
stderr=msg_obj.tool_returns[0].stderr if msg_obj.tool_returns else None,
|
||||
name=msg_obj.name,
|
||||
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index),
|
||||
)
|
||||
|
||||
elif msg.startswith("Error: "):
|
||||
@@ -1282,6 +1309,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
stdout=msg_obj.tool_returns[0].stdout if msg_obj.tool_returns else None,
|
||||
stderr=msg_obj.tool_returns[0].stderr if msg_obj.tool_returns else None,
|
||||
name=msg_obj.name,
|
||||
otid=Message.generate_otid_from_id(msg_obj.id, chunk_index),
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -892,6 +892,7 @@ class SyncServer(Server):
|
||||
use_assistant_message=use_assistant_message,
|
||||
assistant_message_tool_name=assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=assistant_message_tool_kwarg,
|
||||
reverse=reverse,
|
||||
)
|
||||
|
||||
if reverse:
|
||||
|
||||
@@ -33,7 +33,7 @@ class AgentChunkStreamingInterface(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
|
||||
"""Letta generates some internal monologue"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -43,13 +43,18 @@ class AgentChunkStreamingInterface(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def function_message(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
def function_message(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
|
||||
"""Letta calls a function"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def process_chunk(
|
||||
self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime, expect_reasoning_content: bool = False
|
||||
self,
|
||||
chunk: ChatCompletionChunkResponse,
|
||||
message_id: str,
|
||||
message_date: datetime,
|
||||
expect_reasoning_content: bool = False,
|
||||
chunk_index: int = 0,
|
||||
):
|
||||
"""Process a streaming chunk from an OpenAI-compatible server"""
|
||||
raise NotImplementedError
|
||||
@@ -166,7 +171,7 @@ class StreamingCLIInterface(AgentChunkStreamingInterface):
|
||||
StreamingCLIInterface.nonstreaming_interface(msg)
|
||||
|
||||
@staticmethod
|
||||
def internal_monologue(msg: str, msg_obj: Optional[Message] = None):
|
||||
def internal_monologue(msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
|
||||
StreamingCLIInterface.nonstreaming_interface(msg, msg_obj)
|
||||
|
||||
@staticmethod
|
||||
@@ -186,7 +191,7 @@ class StreamingCLIInterface(AgentChunkStreamingInterface):
|
||||
StreamingCLIInterface.nonstreaming_interface(msg, msg_obj)
|
||||
|
||||
@staticmethod
|
||||
def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG):
|
||||
def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG, chunk_index: Optional[int] = None):
|
||||
StreamingCLIInterface.nonstreaming_interface(msg, msg_obj)
|
||||
|
||||
@staticmethod
|
||||
@@ -218,7 +223,7 @@ class AgentRefreshStreamingInterface(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
|
||||
"""Letta generates some internal monologue"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -228,7 +233,7 @@ class AgentRefreshStreamingInterface(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def function_message(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
def function_message(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
|
||||
"""Letta calls a function"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -355,7 +360,7 @@ class StreamingRefreshCLIInterface(AgentRefreshStreamingInterface):
|
||||
def warning_message(msg: str):
|
||||
StreamingCLIInterface.nonstreaming_interface.warning_message(msg)
|
||||
|
||||
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None, chunk_index: Optional[int] = None):
|
||||
if self.disable_inner_mono_call:
|
||||
return
|
||||
StreamingCLIInterface.nonstreaming_interface.internal_monologue(msg, msg_obj)
|
||||
@@ -378,7 +383,7 @@ class StreamingRefreshCLIInterface(AgentRefreshStreamingInterface):
|
||||
StreamingCLIInterface.nonstreaming_interface.user_message(msg, msg_obj)
|
||||
|
||||
@staticmethod
|
||||
def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG):
|
||||
def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG, chunk_index: Optional[int] = None):
|
||||
StreamingCLIInterface.nonstreaming_interface.function_message(msg, msg_obj)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -6,6 +6,8 @@ import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import AgentState, Letta, LlmConfig, MessageCreate
|
||||
|
||||
from letta.schemas.message import Message
|
||||
|
||||
|
||||
def run_server():
|
||||
load_dotenv()
|
||||
@@ -73,9 +75,16 @@ def test_streaming_send_message(
|
||||
client.agents.modify(agent_id=agent.id, llm_config=config)
|
||||
|
||||
# Send streaming message
|
||||
user_message_otid = Message.generate_otid()
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent.id,
|
||||
messages=[MessageCreate(role="user", content="This is a test. Repeat after me: 'banana'")],
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content="This is a test. Repeat after me: 'banana'",
|
||||
otid=user_message_otid,
|
||||
),
|
||||
],
|
||||
stream_tokens=stream_tokens,
|
||||
)
|
||||
|
||||
@@ -84,6 +93,8 @@ def test_streaming_send_message(
|
||||
inner_thoughts_count = 0
|
||||
send_message_ran = False
|
||||
done = False
|
||||
last_message_id = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id
|
||||
letta_message_otids = [user_message_otid]
|
||||
|
||||
assert response, "Sending message failed"
|
||||
for chunk in response:
|
||||
@@ -104,6 +115,8 @@ def test_streaming_send_message(
|
||||
assert chunk.prompt_tokens > 1000
|
||||
assert chunk.total_tokens > 1000
|
||||
done = True
|
||||
else:
|
||||
letta_message_otids.append(chunk.otid)
|
||||
print(chunk)
|
||||
|
||||
# If stream tokens, we expect at least one inner thought
|
||||
@@ -111,3 +124,6 @@ def test_streaming_send_message(
|
||||
assert inner_thoughts_exist, "No inner thoughts found"
|
||||
assert send_message_ran, "send_message function call not found"
|
||||
assert done, "Message stream not done"
|
||||
|
||||
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id)
|
||||
assert [message.otid for message in messages] == letta_message_otids
|
||||
|
||||
Reference in New Issue
Block a user