feat: add run id to streamed messages (#5037)

This commit is contained in:
cthomas
2025-09-30 16:54:00 -07:00
committed by Caren Thomas
parent 6255d59bba
commit 67f8e46619
6 changed files with 60 additions and 1 deletions

View File

@@ -66,9 +66,11 @@ class AnthropicStreamingInterface:
use_assistant_message: bool = False,
put_inner_thoughts_in_kwarg: bool = False,
requires_approval_tools: list = [],
run_id: str | None = None,
):
self.json_parser: JSONParser = PydanticJSONParser()
self.use_assistant_message = use_assistant_message
self.run_id = run_id
# Premake IDs for database writes
self.letta_message_id = Message.generate_id()
@@ -280,6 +282,7 @@ class AnthropicStreamingInterface:
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
)
self.tool_call_buffer.append(tool_call_msg)
elif isinstance(content, BetaThinkingBlock):
@@ -295,6 +298,7 @@ class AnthropicStreamingInterface:
hidden_reasoning=content.data,
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
)
self.reasoning_messages.append(hidden_reasoning_message)
prev_message_type = hidden_reasoning_message.message_type
@@ -340,6 +344,7 @@ class AnthropicStreamingInterface:
reasoning=self.accumulated_inner_thoughts[-1],
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
)
self.reasoning_messages.append(reasoning_message)
prev_message_type = reasoning_message.message_type
@@ -367,6 +372,7 @@ class AnthropicStreamingInterface:
reasoning=inner_thoughts_diff,
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
)
self.reasoning_messages.append(reasoning_message)
prev_message_type = reasoning_message.message_type
@@ -397,6 +403,7 @@ class AnthropicStreamingInterface:
tool_call_id=self.tool_call_id,
arguments=tool_call_args,
),
run_id=self.run_id,
)
prev_message_type = approval_msg.message_type
yield approval_msg
@@ -420,6 +427,7 @@ class AnthropicStreamingInterface:
tool_call_id=self.tool_call_id,
arguments=tool_call_args,
),
run_id=self.run_id,
)
prev_message_type = tool_call_msg.message_type
yield tool_call_msg
@@ -440,6 +448,7 @@ class AnthropicStreamingInterface:
content=[TextContent(text=send_message_diff)],
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
)
prev_message_type = assistant_msg.message_type
yield assistant_msg
@@ -450,12 +459,14 @@ class AnthropicStreamingInterface:
id=self.letta_message_id,
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json),
date=datetime.now(timezone.utc).isoformat(),
run_id=self.run_id,
)
else:
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json),
date=datetime.now(timezone.utc).isoformat(),
run_id=self.run_id,
)
if self.inner_thoughts_complete:
if prev_message_type and prev_message_type != "tool_call_message":
@@ -483,6 +494,7 @@ class AnthropicStreamingInterface:
reasoning=delta.thinking,
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
)
self.reasoning_messages.append(reasoning_message)
prev_message_type = reasoning_message.message_type
@@ -503,6 +515,7 @@ class AnthropicStreamingInterface:
date=datetime.now(timezone.utc).isoformat(),
signature=delta.signature,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
)
self.reasoning_messages.append(reasoning_message)
prev_message_type = reasoning_message.message_type
@@ -536,8 +549,10 @@ class SimpleAnthropicStreamingInterface:
def __init__(
self,
requires_approval_tools: list = [],
run_id: str | None = None,
):
self.json_parser: JSONParser = PydanticJSONParser()
self.run_id = run_id
# Premake IDs for database writes
self.letta_message_id = Message.generate_id()
@@ -748,6 +763,7 @@ class SimpleAnthropicStreamingInterface:
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
)
else:
if prev_message_type and prev_message_type != "tool_call_message":
@@ -757,6 +773,7 @@ class SimpleAnthropicStreamingInterface:
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
)
prev_message_type = tool_call_msg.message_type
yield tool_call_msg
@@ -777,6 +794,7 @@ class SimpleAnthropicStreamingInterface:
hidden_reasoning=content.data,
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
)
self.reasoning_messages.append(hidden_reasoning_message)
@@ -800,6 +818,7 @@ class SimpleAnthropicStreamingInterface:
content=delta.text,
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
)
# self.assistant_messages.append(assistant_msg)
self.reasoning_messages.append(assistant_msg)
@@ -822,6 +841,7 @@ class SimpleAnthropicStreamingInterface:
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json),
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
)
else:
if prev_message_type and prev_message_type != "tool_call_message":
@@ -831,6 +851,7 @@ class SimpleAnthropicStreamingInterface:
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json),
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
)
yield tool_call_msg
@@ -850,6 +871,7 @@ class SimpleAnthropicStreamingInterface:
reasoning=delta.thinking,
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
)
self.reasoning_messages.append(reasoning_message)
prev_message_type = reasoning_message.message_type
@@ -871,6 +893,7 @@ class SimpleAnthropicStreamingInterface:
date=datetime.now(timezone.utc).isoformat(),
signature=delta.signature,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
)
self.reasoning_messages.append(reasoning_message)
prev_message_type = reasoning_message.message_type