fix: remove separate tool call id in streaming path (#2641)
This commit is contained in:
@@ -515,8 +515,7 @@ class LettaAgent(BaseAgent):
|
||||
total_tokens=interface.input_tokens + interface.output_tokens,
|
||||
),
|
||||
reasoning_content=reasoning_content,
|
||||
-pre_computed_assistant_message_id=interface.letta_assistant_message_id,
|
||||
-pre_computed_tool_message_id=interface.letta_tool_message_id,
|
||||
+pre_computed_assistant_message_id=interface.letta_message_id,
|
||||
step_id=step_id,
|
||||
agent_step_span=agent_step_span,
|
||||
)
|
||||
@@ -811,7 +810,6 @@ class LettaAgent(BaseAgent):
|
||||
usage: UsageStatistics,
|
||||
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
|
||||
pre_computed_assistant_message_id: Optional[str] = None,
|
||||
-pre_computed_tool_message_id: Optional[str] = None,
|
||||
step_id: str | None = None,
|
||||
new_in_context_messages: Optional[List[Message]] = None,
|
||||
agent_step_span: Optional["Span"] = None,
|
||||
@@ -927,7 +925,6 @@ class LettaAgent(BaseAgent):
|
||||
add_heartbeat_request_system_message=continue_stepping,
|
||||
reasoning_content=reasoning_content,
|
||||
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
|
||||
-pre_computed_tool_message_id=pre_computed_tool_message_id,
|
||||
step_id=logged_step.id if logged_step else None, # TODO (cliandy): eventually move over other agent loops
|
||||
)
|
||||
|
||||
|
||||
@@ -551,7 +551,6 @@ class LettaAgentBatch(BaseAgent):
|
||||
add_heartbeat_request_system_message=False,
|
||||
reasoning_content=reasoning_content,
|
||||
pre_computed_assistant_message_id=None,
|
||||
-pre_computed_tool_message_id=None,
|
||||
llm_batch_item_id=llm_batch_item_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -62,8 +62,7 @@ class AnthropicStreamingInterface:
|
||||
self.use_assistant_message = use_assistant_message
|
||||
|
||||
# Premake IDs for database writes
|
||||
-self.letta_assistant_message_id = Message.generate_id()
|
||||
-self.letta_tool_message_id = Message.generate_id()
|
||||
+self.letta_message_id = Message.generate_id()
|
||||
|
||||
self.anthropic_mode = None
|
||||
self.message_id = None
|
||||
@@ -152,7 +151,7 @@ class AnthropicStreamingInterface:
|
||||
if not self.use_assistant_message:
|
||||
# Buffer the initial tool call message instead of yielding immediately
|
||||
tool_call_msg = ToolCallMessage(
|
||||
-id=self.letta_tool_message_id,
|
||||
+id=self.letta_message_id,
|
||||
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
@@ -165,11 +164,11 @@ class AnthropicStreamingInterface:
|
||||
if prev_message_type and prev_message_type != "hidden_reasoning_message":
|
||||
message_index += 1
|
||||
hidden_reasoning_message = HiddenReasoningMessage(
|
||||
-id=self.letta_assistant_message_id,
|
||||
+id=self.letta_message_id,
|
||||
state="redacted",
|
||||
hidden_reasoning=content.data,
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
-otid=Message.generate_otid_from_id(self.letta_assistant_message_id, message_index),
|
||||
+otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
self.reasoning_messages.append(hidden_reasoning_message)
|
||||
prev_message_type = hidden_reasoning_message.message_type
|
||||
@@ -206,10 +205,10 @@ class AnthropicStreamingInterface:
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
reasoning_message = ReasoningMessage(
|
||||
-id=self.letta_assistant_message_id,
|
||||
+id=self.letta_message_id,
|
||||
reasoning=self.accumulated_inner_thoughts[-1],
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
-otid=Message.generate_otid_from_id(self.letta_assistant_message_id, message_index),
|
||||
+otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
@@ -233,10 +232,10 @@ class AnthropicStreamingInterface:
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
reasoning_message = ReasoningMessage(
|
||||
-id=self.letta_assistant_message_id,
|
||||
+id=self.letta_message_id,
|
||||
reasoning=inner_thoughts_diff,
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
-otid=Message.generate_otid_from_id(self.letta_assistant_message_id, message_index),
|
||||
+otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
@@ -250,7 +249,7 @@ class AnthropicStreamingInterface:
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
for buffered_msg in self.tool_call_buffer:
|
||||
-buffered_msg.otid = Message.generate_otid_from_id(self.letta_tool_message_id, message_index)
|
||||
+buffered_msg.otid = Message.generate_otid_from_id(self.letta_message_id, message_index)
|
||||
prev_message_type = buffered_msg.message_type
|
||||
yield buffered_msg
|
||||
self.tool_call_buffer = []
|
||||
@@ -266,17 +265,17 @@ class AnthropicStreamingInterface:
|
||||
if prev_message_type and prev_message_type != "assistant_message":
|
||||
message_index += 1
|
||||
assistant_msg = AssistantMessage(
|
||||
-id=self.letta_assistant_message_id,
|
||||
+id=self.letta_message_id,
|
||||
content=[TextContent(text=send_message_diff)],
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
-otid=Message.generate_otid_from_id(self.letta_assistant_message_id, message_index),
|
||||
+otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
prev_message_type = assistant_msg.message_type
|
||||
yield assistant_msg
|
||||
else:
|
||||
# Otherwise, it is a normal tool call - buffer or yield based on inner thoughts status
|
||||
tool_call_msg = ToolCallMessage(
|
||||
-id=self.letta_tool_message_id,
|
||||
+id=self.letta_message_id,
|
||||
tool_call=ToolCallDelta(
|
||||
name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json
|
||||
),
|
||||
@@ -285,7 +284,7 @@ class AnthropicStreamingInterface:
|
||||
if self.inner_thoughts_complete:
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
-tool_call_msg.otid = Message.generate_otid_from_id(self.letta_tool_message_id, message_index)
|
||||
+tool_call_msg.otid = Message.generate_otid_from_id(self.letta_message_id, message_index)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
else:
|
||||
@@ -303,11 +302,11 @@ class AnthropicStreamingInterface:
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
reasoning_message = ReasoningMessage(
|
||||
-id=self.letta_assistant_message_id,
|
||||
+id=self.letta_message_id,
|
||||
source="reasoner_model",
|
||||
reasoning=delta.thinking,
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
-otid=Message.generate_otid_from_id(self.letta_assistant_message_id, message_index),
|
||||
+otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
@@ -322,12 +321,12 @@ class AnthropicStreamingInterface:
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
reasoning_message = ReasoningMessage(
|
||||
-id=self.letta_assistant_message_id,
|
||||
+id=self.letta_message_id,
|
||||
source="reasoner_model",
|
||||
reasoning="",
|
||||
date=datetime.now(timezone.utc).isoformat(),
|
||||
signature=delta.signature,
|
||||
-otid=Message.generate_otid_from_id(self.letta_assistant_message_id, message_index),
|
||||
+otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
)
|
||||
self.reasoning_messages.append(reasoning_message)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
|
||||
@@ -190,7 +190,6 @@ def create_letta_messages_from_llm_response(
|
||||
add_heartbeat_request_system_message: bool = False,
|
||||
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
|
||||
pre_computed_assistant_message_id: Optional[str] = None,
|
||||
-pre_computed_tool_message_id: Optional[str] = None,
|
||||
llm_batch_item_id: Optional[str] = None,
|
||||
step_id: str | None = None,
|
||||
) -> List[Message]:
|
||||
@@ -245,8 +244,6 @@ def create_letta_messages_from_llm_response(
|
||||
)
|
||||
],
|
||||
)
|
||||
-if pre_computed_tool_message_id:
|
||||
-tool_message.id = pre_computed_tool_message_id
|
||||
messages.append(tool_message)
|
||||
|
||||
if add_heartbeat_request_system_message:
|
||||
|
||||
@@ -182,6 +182,9 @@ async def test_pinecone_tool(client: AsyncLetta) -> None:
|
||||
stream_message = response_messages_from_stream[idx]
|
||||
db_message = response_messages_from_db[idx]
|
||||
assert stream_message.message_type == db_message.message_type
|
||||
+print("message type:", stream_message.message_type)
|
||||
+print("stream message:", stream_message.model_dump_json(indent=4))
|
||||
+print("db message:", db_message.model_dump_json(indent=4))
|
||||
assert stream_message.id == db_message.id
|
||||
assert stream_message.otid == db_message.otid
|
||||
|
||||
|
||||
Reference in New Issue
Block a user