feat: add approval create input to messages endpoints [LET-4110] (#4309)

* feat: add approval create input to messages endpoints

* rename discriminator tag

* add base class with default

* add field validator

* exclude new type field from agent file schema
This commit is contained in:
cthomas
2025-08-29 13:16:03 -07:00
committed by GitHub
parent 26309264a4
commit bfdb586f74
6 changed files with 70 additions and 16 deletions

View File

@@ -9,7 +9,7 @@ from letta.schemas.agent import AgentState
from letta.schemas.letta_message import MessageType
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message, MessageCreate
from letta.schemas.message import Message, MessageCreate, MessageCreateBase
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
@@ -121,7 +121,7 @@ async def _prepare_in_context_messages_async(
async def _prepare_in_context_messages_no_persist_async(
input_messages: List[MessageCreate],
input_messages: List[MessageCreateBase],
agent_state: AgentState,
message_manager: MessageManager,
actor: User,

View File

@@ -40,7 +40,7 @@ from letta.schemas.letta_message_content import OmittedReasoningContent, Reasoni
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message, MessageCreate
from letta.schemas.message import Message, MessageCreateBase
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
from letta.schemas.provider_trace import ProviderTraceCreate
from letta.schemas.step import StepProgression
@@ -164,7 +164,7 @@ class LettaAgent(BaseAgent):
@trace_method
async def step(
self,
input_messages: list[MessageCreate],
input_messages: list[MessageCreateBase],
max_steps: int = DEFAULT_MAX_STEPS,
run_id: str | None = None,
use_assistant_message: bool = True,
@@ -203,7 +203,7 @@ class LettaAgent(BaseAgent):
@trace_method
async def step_stream_no_tokens(
self,
input_messages: list[MessageCreate],
input_messages: list[MessageCreateBase],
max_steps: int = DEFAULT_MAX_STEPS,
use_assistant_message: bool = True,
request_start_timestamp_ns: int | None = None,
@@ -501,7 +501,7 @@ class LettaAgent(BaseAgent):
async def _step(
self,
agent_state: AgentState,
input_messages: list[MessageCreate],
input_messages: list[MessageCreateBase],
max_steps: int = DEFAULT_MAX_STEPS,
run_id: str | None = None,
request_start_timestamp_ns: int | None = None,
@@ -807,7 +807,7 @@ class LettaAgent(BaseAgent):
@trace_method
async def step_stream(
self,
input_messages: list[MessageCreate],
input_messages: list[MessageCreateBase],
max_steps: int = DEFAULT_MAX_STEPS,
use_assistant_message: bool = True,
request_start_timestamp_ns: int | None = None,

View File

@@ -1,14 +1,14 @@
from typing import List, Optional
from pydantic import BaseModel, Field, HttpUrl
from pydantic import BaseModel, Field, HttpUrl, field_validator
from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.schemas.letta_message import MessageType
from letta.schemas.message import MessageCreate
from letta.schemas.message import MessageCreateUnion
class LettaRequest(BaseModel):
messages: List[MessageCreate] = Field(..., description="The messages to be sent to the agent.")
messages: List[MessageCreateUnion] = Field(..., description="The messages to be sent to the agent.")
max_steps: int = Field(
default=DEFAULT_MAX_STEPS,
description="Maximum number of steps the agent should take to process the request.",
@@ -36,6 +36,16 @@ class LettaRequest(BaseModel):
description="If set to True, enables reasoning before responses or tool calls from the agent.",
)
@field_validator("messages", mode="before")
@classmethod
def add_default_type_to_messages(cls, v):
"""Add default 'message' type for backwards compatibility with older versions of SDK clients that don't send it"""
if isinstance(v, list):
for item in v:
if isinstance(item, dict) and "type" not in item:
item["type"] = "message"
return v
class LettaStreamingRequest(LettaRequest):
stream_tokens: bool = Field(

View File

@@ -7,10 +7,11 @@ import uuid
import warnings
from collections import OrderedDict
from datetime import datetime, timezone
from typing import Any, Dict, List, Literal, Optional, Union
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall, Function as OpenAIFunction
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, TOOL_CALL_ID_MAX_LEN
from letta.helpers.datetime_helpers import get_utc_time, is_utc_datetime
@@ -65,13 +66,19 @@ def add_inner_thoughts_to_tool_call(
raise e
class BaseMessage(OrmMetadataBase):
__id_prefix__ = "message"
class MessageCreateType(str, Enum):
message = "message"
approval = "approval"
class MessageCreate(BaseModel):
class MessageCreateBase(BaseModel):
type: MessageCreateType = Field(..., description="The message type to be created.")
class MessageCreate(MessageCreateBase):
"""Request to create a message"""
type: Literal[MessageCreateType.message] = Field(default=MessageCreateType.message, description="The message type to be created.")
# In the simplified format, only allow simple roles
role: Literal[
MessageRole.user,
@@ -97,6 +104,37 @@ class MessageCreate(BaseModel):
return data
class ApprovalCreate(MessageCreateBase):
"""Input to approve or deny a tool call request"""
type: Literal[MessageCreateType.approval] = Field(default=MessageCreateType.approval, description="The message type to be created.")
approve: bool = Field(..., description="Whether the tool has been approved")
approval_request_id: str = Field(..., description="The message ID of the approval request")
reason: Optional[str] = Field(None, description="An optional explanation for the provided approval status")
MessageCreateUnion = Annotated[
Union[MessageCreate, ApprovalCreate],
Field(discriminator="type"),
]
def create_message_create_union_schema():
return {
"oneOf": [
{"$ref": "#/components/schemas/MessageCreate"},
{"$ref": "#/components/schemas/ApprovalCreate"},
],
"discriminator": {
"propertyName": "type",
"mapping": {
"message": "#/components/schemas/MessageCreate",
"approval": "#/components/schemas/ApprovalCreate",
},
},
}
class MessageUpdate(BaseModel):
"""Request to update a message"""
@@ -125,6 +163,10 @@ class MessageUpdate(BaseModel):
return data
class BaseMessage(OrmMetadataBase):
__id_prefix__ = "message"
class Message(BaseMessage):
"""
Letta's internal representation of a message. Includes methods to convert to/from LLM provider formats.

View File

@@ -29,6 +29,7 @@ from letta.schemas.letta_message_content import (
create_letta_user_message_content_union_schema,
)
from letta.schemas.letta_ping import create_letta_ping_schema
from letta.schemas.message import create_message_create_union_schema
from letta.server.constants import REST_DEFAULT_PORT
from letta.server.db import db_registry
@@ -65,6 +66,7 @@ def generate_openapi_schema(app: FastAPI):
letta_docs["paths"] = {k: v for k, v in letta_docs["paths"].items() if not k.startswith("/openai")}
letta_docs["info"]["title"] = "Letta API"
letta_docs["components"]["schemas"]["LettaMessageUnion"] = create_letta_message_union_schema()
letta_docs["components"]["schemas"]["MessageCreateUnion"] = create_message_create_union_schema()
letta_docs["components"]["schemas"]["LettaMessageContentUnion"] = create_letta_message_content_union_schema()
letta_docs["components"]["schemas"]["LettaAssistantMessageContentUnion"] = create_letta_assistant_message_content_union_schema()
letta_docs["components"]["schemas"]["LettaUserMessageContentUnion"] = create_letta_user_message_content_union_schema()

View File

@@ -668,7 +668,7 @@ class AgentSerializationManager:
messages = []
for message_schema in agent_schema.messages:
# Convert MessageSchema back to Message, setting agent_id to new DB ID
message_data = message_schema.model_dump(exclude={"id"})
message_data = message_schema.model_dump(exclude={"id", "type"})
message_data["agent_id"] = agent_db_id # Remap agent_id to new database ID
message_obj = Message(**message_data)
messages.append(message_obj)