feat: add approval create input to messages endpoints [LET-4110] (#4309)
* feat: add approval create input to messages endpoints * rename discriminator tag * add base class with default * add field validator * exclude new type field from agent file schema
This commit is contained in:
@@ -9,7 +9,7 @@ from letta.schemas.agent import AgentState
|
||||
from letta.schemas.letta_message import MessageType
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.message import Message, MessageCreate, MessageCreateBase
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
@@ -121,7 +121,7 @@ async def _prepare_in_context_messages_async(
|
||||
|
||||
|
||||
async def _prepare_in_context_messages_no_persist_async(
|
||||
input_messages: List[MessageCreate],
|
||||
input_messages: List[MessageCreateBase],
|
||||
agent_state: AgentState,
|
||||
message_manager: MessageManager,
|
||||
actor: User,
|
||||
|
||||
@@ -40,7 +40,7 @@ from letta.schemas.letta_message_content import OmittedReasoningContent, Reasoni
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.message import Message, MessageCreateBase
|
||||
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
|
||||
from letta.schemas.provider_trace import ProviderTraceCreate
|
||||
from letta.schemas.step import StepProgression
|
||||
@@ -164,7 +164,7 @@ class LettaAgent(BaseAgent):
|
||||
@trace_method
|
||||
async def step(
|
||||
self,
|
||||
input_messages: list[MessageCreate],
|
||||
input_messages: list[MessageCreateBase],
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
run_id: str | None = None,
|
||||
use_assistant_message: bool = True,
|
||||
@@ -203,7 +203,7 @@ class LettaAgent(BaseAgent):
|
||||
@trace_method
|
||||
async def step_stream_no_tokens(
|
||||
self,
|
||||
input_messages: list[MessageCreate],
|
||||
input_messages: list[MessageCreateBase],
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
use_assistant_message: bool = True,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
@@ -501,7 +501,7 @@ class LettaAgent(BaseAgent):
|
||||
async def _step(
|
||||
self,
|
||||
agent_state: AgentState,
|
||||
input_messages: list[MessageCreate],
|
||||
input_messages: list[MessageCreateBase],
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
run_id: str | None = None,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
@@ -807,7 +807,7 @@ class LettaAgent(BaseAgent):
|
||||
@trace_method
|
||||
async def step_stream(
|
||||
self,
|
||||
input_messages: list[MessageCreate],
|
||||
input_messages: list[MessageCreateBase],
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
use_assistant_message: bool = True,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, HttpUrl
|
||||
from pydantic import BaseModel, Field, HttpUrl, field_validator
|
||||
|
||||
from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.schemas.letta_message import MessageType
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.message import MessageCreateUnion
|
||||
|
||||
|
||||
class LettaRequest(BaseModel):
|
||||
messages: List[MessageCreate] = Field(..., description="The messages to be sent to the agent.")
|
||||
messages: List[MessageCreateUnion] = Field(..., description="The messages to be sent to the agent.")
|
||||
max_steps: int = Field(
|
||||
default=DEFAULT_MAX_STEPS,
|
||||
description="Maximum number of steps the agent should take to process the request.",
|
||||
@@ -36,6 +36,16 @@ class LettaRequest(BaseModel):
|
||||
description="If set to True, enables reasoning before responses or tool calls from the agent.",
|
||||
)
|
||||
|
||||
@field_validator("messages", mode="before")
|
||||
@classmethod
|
||||
def add_default_type_to_messages(cls, v):
|
||||
"""Add default 'message' type for backwards compatibility with older versions of SDK clients that don't send it"""
|
||||
if isinstance(v, list):
|
||||
for item in v:
|
||||
if isinstance(item, dict) and "type" not in item:
|
||||
item["type"] = "message"
|
||||
return v
|
||||
|
||||
|
||||
class LettaStreamingRequest(LettaRequest):
|
||||
stream_tokens: bool = Field(
|
||||
|
||||
@@ -7,10 +7,11 @@ import uuid
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall, Function as OpenAIFunction
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, TOOL_CALL_ID_MAX_LEN
|
||||
from letta.helpers.datetime_helpers import get_utc_time, is_utc_datetime
|
||||
@@ -65,13 +66,19 @@ def add_inner_thoughts_to_tool_call(
|
||||
raise e
|
||||
|
||||
|
||||
class BaseMessage(OrmMetadataBase):
|
||||
__id_prefix__ = "message"
|
||||
class MessageCreateType(str, Enum):
|
||||
message = "message"
|
||||
approval = "approval"
|
||||
|
||||
|
||||
class MessageCreate(BaseModel):
|
||||
class MessageCreateBase(BaseModel):
|
||||
type: MessageCreateType = Field(..., description="The message type to be created.")
|
||||
|
||||
|
||||
class MessageCreate(MessageCreateBase):
|
||||
"""Request to create a message"""
|
||||
|
||||
type: Literal[MessageCreateType.message] = Field(default=MessageCreateType.message, description="The message type to be created.")
|
||||
# In the simplified format, only allow simple roles
|
||||
role: Literal[
|
||||
MessageRole.user,
|
||||
@@ -97,6 +104,37 @@ class MessageCreate(BaseModel):
|
||||
return data
|
||||
|
||||
|
||||
class ApprovalCreate(MessageCreateBase):
|
||||
"""Input to approve or deny a tool call request"""
|
||||
|
||||
type: Literal[MessageCreateType.approval] = Field(default=MessageCreateType.approval, description="The message type to be created.")
|
||||
approve: bool = Field(..., description="Whether the tool has been approved")
|
||||
approval_request_id: str = Field(..., description="The message ID of the approval request")
|
||||
reason: Optional[str] = Field(None, description="An optional explanation for the provided approval status")
|
||||
|
||||
|
||||
MessageCreateUnion = Annotated[
|
||||
Union[MessageCreate, ApprovalCreate],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
def create_message_create_union_schema():
|
||||
return {
|
||||
"oneOf": [
|
||||
{"$ref": "#/components/schemas/MessageCreate"},
|
||||
{"$ref": "#/components/schemas/ApprovalCreate"},
|
||||
],
|
||||
"discriminator": {
|
||||
"propertyName": "type",
|
||||
"mapping": {
|
||||
"message": "#/components/schemas/MessageCreate",
|
||||
"approval": "#/components/schemas/ApprovalCreate",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class MessageUpdate(BaseModel):
|
||||
"""Request to update a message"""
|
||||
|
||||
@@ -125,6 +163,10 @@ class MessageUpdate(BaseModel):
|
||||
return data
|
||||
|
||||
|
||||
class BaseMessage(OrmMetadataBase):
|
||||
__id_prefix__ = "message"
|
||||
|
||||
|
||||
class Message(BaseMessage):
|
||||
"""
|
||||
Letta's internal representation of a message. Includes methods to convert to/from LLM provider formats.
|
||||
|
||||
@@ -29,6 +29,7 @@ from letta.schemas.letta_message_content import (
|
||||
create_letta_user_message_content_union_schema,
|
||||
)
|
||||
from letta.schemas.letta_ping import create_letta_ping_schema
|
||||
from letta.schemas.message import create_message_create_union_schema
|
||||
from letta.server.constants import REST_DEFAULT_PORT
|
||||
from letta.server.db import db_registry
|
||||
|
||||
@@ -65,6 +66,7 @@ def generate_openapi_schema(app: FastAPI):
|
||||
letta_docs["paths"] = {k: v for k, v in letta_docs["paths"].items() if not k.startswith("/openai")}
|
||||
letta_docs["info"]["title"] = "Letta API"
|
||||
letta_docs["components"]["schemas"]["LettaMessageUnion"] = create_letta_message_union_schema()
|
||||
letta_docs["components"]["schemas"]["MessageCreateUnion"] = create_message_create_union_schema()
|
||||
letta_docs["components"]["schemas"]["LettaMessageContentUnion"] = create_letta_message_content_union_schema()
|
||||
letta_docs["components"]["schemas"]["LettaAssistantMessageContentUnion"] = create_letta_assistant_message_content_union_schema()
|
||||
letta_docs["components"]["schemas"]["LettaUserMessageContentUnion"] = create_letta_user_message_content_union_schema()
|
||||
|
||||
@@ -668,7 +668,7 @@ class AgentSerializationManager:
|
||||
messages = []
|
||||
for message_schema in agent_schema.messages:
|
||||
# Convert MessageSchema back to Message, setting agent_id to new DB ID
|
||||
message_data = message_schema.model_dump(exclude={"id"})
|
||||
message_data = message_schema.model_dump(exclude={"id", "type"})
|
||||
message_data["agent_id"] = agent_db_id # Remap agent_id to new database ID
|
||||
message_obj = Message(**message_data)
|
||||
messages.append(message_obj)
|
||||
|
||||
Reference in New Issue
Block a user