# Changelog (recent): auto fixes; auto fix pt2, transitive deps, undefined-var checking via locals(); manual fixes (ignored or letta-code fixed); fix circular import
import html
|
|
import json
|
|
import re
|
|
from datetime import datetime
|
|
from typing import Any, List, Literal, Optional, Union
|
|
|
|
from pydantic import BaseModel, Field, RootModel
|
|
|
|
from letta.helpers.json_helpers import json_dumps
|
|
from letta.schemas.enums import JobStatus
|
|
from letta.schemas.letta_message import (
|
|
ApprovalRequestMessage,
|
|
ApprovalResponseMessage,
|
|
AssistantMessage,
|
|
HiddenReasoningMessage,
|
|
LettaErrorMessage,
|
|
LettaMessageUnion,
|
|
LettaPing,
|
|
ReasoningMessage,
|
|
SystemMessage,
|
|
ToolCallMessage,
|
|
ToolReturnMessage,
|
|
UserMessage,
|
|
)
|
|
from letta.schemas.letta_stop_reason import LettaStopReason
|
|
from letta.schemas.message import Message
|
|
from letta.schemas.openai.chat_completion_response import ChoiceLogprobs
|
|
from letta.schemas.usage import LettaUsageStatistics
|
|
|
|
# TODO: consider moving into own file
|
|
|
|
|
|
class TurnTokenData(BaseModel):
    """Token data for a single LLM generation turn in a multi-turn agent interaction.

    Used for RL training to track token IDs and logprobs across all LLM calls,
    not just the final one. Tool results are included so the client can tokenize
    them with loss_mask=0 (non-trainable).
    """

    # Discriminates trainable assistant generations from non-trainable tool results.
    role: Literal["assistant", "tool"] = Field(
        ..., description="Role of this turn: 'assistant' for LLM generations (trainable), 'tool' for tool results (non-trainable)."
    )
    # Raw token IDs of the generation; None on tool turns.
    output_ids: Optional[List[int]] = Field(None, description="Token IDs from SGLang native endpoint. Only present for assistant turns.")
    # Each inner list is [logprob, token_id, top_logprob_or_null]; None on tool turns.
    output_token_logprobs: Optional[List[List[Any]]] = Field(
        None, description="Logprobs from SGLang: [[logprob, token_id, top_logprob_or_null], ...]. Only present for assistant turns."
    )
    # Plain text of the turn; clients tokenize tool-turn content with loss_mask=0.
    content: Optional[str] = Field(None, description="Text content. For tool turns, client tokenizes this with loss_mask=0.")
    tool_name: Optional[str] = Field(None, description="Name of the tool called. Only present for tool turns.")
|
|
|
class LettaResponse(BaseModel):
    """
    Response object from an agent interaction, consisting of the new messages generated by the agent and usage statistics.
    The type of the returned messages can be either `Message` or `LettaMessage`, depending on what was specified in the request.

    Attributes:
        messages (List[Union[Message, LettaMessage]]): The messages returned by the agent.
        usage (LettaUsageStatistics): The usage statistics
    """

    messages: List[LettaMessageUnion] = Field(
        ...,
        description="The messages returned by the agent.",
        json_schema_extra={
            "items": {
                "$ref": "#/components/schemas/LettaMessageUnion",
            }
        },
    )
    stop_reason: LettaStopReason = Field(
        ...,
        description="The stop reason from Letta indicating why agent loop stopped execution.",
    )
    usage: LettaUsageStatistics = Field(
        ...,
        description="The usage statistics of the agent.",
    )
    logprobs: Optional[ChoiceLogprobs] = Field(
        None,
        description="Log probabilities of the output tokens from the last LLM call. Only present if return_logprobs was enabled.",
    )
    turns: Optional[List[TurnTokenData]] = Field(
        None,
        description="Token data for all LLM generations in multi-turn agent interaction. "
        "Includes token IDs and logprobs for each assistant turn, plus tool result content. "
        "Only present if return_token_ids was enabled. Used for RL training with loss masking.",
    )

    def __str__(self) -> str:
        """Return a pretty-printed JSON string of the messages and usage statistics."""
        return json_dumps(
            {
                "messages": [message.model_dump() for message in self.messages],
                "usage": self.usage.model_dump(),
            },
            indent=4,
        )

    def _repr_html_(self):
        """Render the response as styled HTML (Jupyter/IPython rich display hook)."""

        def get_formatted_content(msg):
            # Map each message type to an HTML snippet; all dynamic text is escaped.
            if msg.message_type == "internal_monologue":
                return f'<div class="content"><span class="internal-monologue">{html.escape(msg.internal_monologue)}</span></div>'
            if msg.message_type == "reasoning_message":
                return f'<div class="content"><span class="internal-monologue">{html.escape(msg.reasoning)}</span></div>'
            elif msg.message_type == "function_call":
                args = format_json(msg.function_call.arguments)
                return f'<div class="content"><span class="function-name">{html.escape(msg.function_call.name)}</span>({args})</div>'
            elif msg.message_type == "tool_call_message":
                args = format_json(msg.tool_call.arguments)
                return f'<div class="content"><span class="function-name">{html.escape(msg.tool_call.name)}</span>({args})</div>'
            elif msg.message_type == "function_return":
                return_value = format_json(msg.function_return)
                return f'<div class="content">{return_value}</div>'
            elif msg.message_type == "tool_return_message":
                return_value = format_json(msg.tool_return)
                return f'<div class="content">{return_value}</div>'
            elif msg.message_type == "user_message":
                # User messages may be raw text or a JSON payload; highlight JSON.
                if is_json(msg.message):
                    return f'<div class="content">{format_json(msg.message)}</div>'
                else:
                    return f'<div class="content">{html.escape(msg.message)}</div>'
            elif msg.message_type in ["assistant_message", "system_message"]:
                return f'<div class="content">{html.escape(msg.message)}</div>'
            else:
                # Unknown message type: fall back to the escaped repr.
                return f'<div class="content">{html.escape(str(msg))}</div>'

        def is_json(string):
            # True when `string` parses as JSON.
            try:
                json.loads(string)
                return True
            except ValueError:
                return False

        def format_json(json_str):
            # Pretty-print and syntax-highlight a JSON string for HTML output.
            # Falls back to plain HTML-escaping when the input is not valid JSON.
            try:
                parsed = json.loads(json_str)
                formatted = json.dumps(parsed, indent=2, ensure_ascii=False)
                # FIX: the escape calls had degraded into no-ops (e.g. replace("&", "&")).
                # Escape '&' first so entities introduced below are not double-escaped.
                formatted = formatted.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
                # Preserve the pretty-printed layout once embedded in HTML.
                formatted = formatted.replace("\n", "<br>").replace(" ", "&nbsp;")
                # Syntax highlighting. Spaces were already turned into &nbsp;, so the
                # value-matching patterns must look for ':&nbsp;' rather than ': '
                # (with ': ' they could never match and values were never highlighted).
                formatted = re.sub(r'(".*?"):', r'<span class="json-key">\1</span>:', formatted)
                formatted = re.sub(r':&nbsp;(".*?")', r':&nbsp;<span class="json-string">\1</span>', formatted)
                formatted = re.sub(r":&nbsp;(\d+)", r':&nbsp;<span class="json-number">\1</span>', formatted)
                formatted = re.sub(r":&nbsp;(true|false)", r':&nbsp;<span class="json-boolean">\1</span>', formatted)
                return formatted
            except json.JSONDecodeError:
                return html.escape(json_str)

        html_output = """
        <style>
            .message-container, .usage-container {
                font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
                max-width: 800px;
                margin: 20px auto;
                background-color: #1e1e1e;
                border-radius: 8px;
                overflow: hidden;
                color: #d4d4d4;
            }
            .message, .usage-stats {
                padding: 10px 15px;
                border-bottom: 1px solid #3a3a3a;
            }
            .message:last-child, .usage-stats:last-child {
                border-bottom: none;
            }
            .title {
                font-weight: bold;
                margin-bottom: 5px;
                color: #ffffff;
                text-transform: uppercase;
                font-size: 0.9em;
            }
            .content {
                background-color: #2d2d2d;
                border-radius: 4px;
                padding: 5px 10px;
                font-family: 'Consolas', 'Courier New', monospace;
                white-space: pre-wrap;
            }
            .json-key, .function-name, .json-boolean { color: #9cdcfe; }
            .json-string { color: #ce9178; }
            .json-number { color: #b5cea8; }
            .internal-monologue { font-style: italic; }
        </style>
        <div class="message-container">
        """

        for msg in self.messages:
            content = get_formatted_content(msg)
            title = msg.message_type.replace("_", " ").upper()
            html_output += f"""
            <div class="message">
                <div class="title">{title}</div>
                {content}
            </div>
            """
        html_output += "</div>"

        # Formatting the usage statistics
        usage_html = json.dumps(self.usage.model_dump(), indent=2)
        html_output += f"""
        <div class="usage-container">
            <div class="usage-stats">
                <div class="title">USAGE STATISTICS</div>
                <div class="content">{format_json(usage_html)}</div>
            </div>
        </div>
        """

        return html_output
|
|
|
# The streaming response can be any of the individual message types, plus metadata types
class LettaStreamingResponse(RootModel):
    """
    Streaming response type for Server-Sent Events (SSE) endpoints.
    Each event in the stream will be one of these types.
    """

    # Discriminated union: pydantic selects the concrete type from each
    # payload's `message_type` field.
    root: Union[
        SystemMessage,
        UserMessage,
        ReasoningMessage,
        HiddenReasoningMessage,
        ToolCallMessage,
        ToolReturnMessage,
        AssistantMessage,
        ApprovalRequestMessage,
        ApprovalResponseMessage,
        LettaPing,
        LettaErrorMessage,
        LettaStopReason,
        LettaUsageStatistics,
    ] = Field(..., discriminator="message_type")
|
|
|
class LettaBatchResponse(BaseModel):
    """Status summary for a batch request submitted to Letta."""

    letta_batch_id: str = Field(..., description="A unique identifier for the Letta batch request.")
    last_llm_batch_id: str = Field(..., description="A unique identifier for the most recent model provider batch request.")
    status: JobStatus = Field(..., description="The current status of the batch request.")
    agent_count: int = Field(..., description="The number of agents in the batch request.")
    last_polled_at: datetime = Field(..., description="The timestamp when the batch was last polled for updates.")
    created_at: datetime = Field(..., description="The timestamp when the batch request was created.")
|
|
|
class LettaBatchMessages(BaseModel):
    """Container for the messages produced by a batch request."""

    # Full `Message` records (not the narrower LettaMessage union).
    messages: List[Message]