diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index 38ca80be..43dd3184 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -9,7 +9,7 @@ from letta.schemas.agent import AgentState from letta.schemas.letta_message import MessageType from letta.schemas.letta_response import LettaResponse from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType -from letta.schemas.message import Message, MessageCreate +from letta.schemas.message import Message, MessageCreate, MessageCreateBase from letta.schemas.tool_execution_result import ToolExecutionResult from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User @@ -121,7 +121,7 @@ async def _prepare_in_context_messages_async( async def _prepare_in_context_messages_no_persist_async( - input_messages: List[MessageCreate], + input_messages: List[MessageCreateBase], agent_state: AgentState, message_manager: MessageManager, actor: User, diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 7e0019ea..6343764f 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -40,7 +40,7 @@ from letta.schemas.letta_message_content import OmittedReasoningContent, Reasoni from letta.schemas.letta_response import LettaResponse from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.llm_config import LLMConfig -from letta.schemas.message import Message, MessageCreate +from letta.schemas.message import Message, MessageCreateBase from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics from letta.schemas.provider_trace import ProviderTraceCreate from letta.schemas.step import StepProgression @@ -164,7 +164,7 @@ class LettaAgent(BaseAgent): @trace_method async def step( self, - input_messages: list[MessageCreate], + input_messages: list[MessageCreateBase], max_steps: int = DEFAULT_MAX_STEPS, run_id: str | None = None, use_assistant_message: bool = True, @@ -203,7 +203,7 @@ class LettaAgent(BaseAgent): @trace_method async def step_stream_no_tokens( self, - input_messages: list[MessageCreate], + input_messages: list[MessageCreateBase], max_steps: int = DEFAULT_MAX_STEPS, use_assistant_message: bool = True, request_start_timestamp_ns: int | None = None, @@ -501,7 +501,7 @@ class LettaAgent(BaseAgent): async def _step( self, agent_state: AgentState, - input_messages: list[MessageCreate], + input_messages: list[MessageCreateBase], max_steps: int = DEFAULT_MAX_STEPS, run_id: str | None = None, request_start_timestamp_ns: int | None = None, @@ -807,7 +807,7 @@ class LettaAgent(BaseAgent): @trace_method async def step_stream( self, - input_messages: list[MessageCreate], + input_messages: list[MessageCreateBase], max_steps: int = DEFAULT_MAX_STEPS, use_assistant_message: bool = True, request_start_timestamp_ns: int | None = None, diff --git a/letta/schemas/letta_request.py b/letta/schemas/letta_request.py index 5da3bca6..f02d1635 100644 --- a/letta/schemas/letta_request.py +++ b/letta/schemas/letta_request.py @@ -1,14 +1,14 @@ from typing import List, Optional -from pydantic import BaseModel, Field, HttpUrl +from pydantic import BaseModel, Field, HttpUrl, field_validator from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.schemas.letta_message import MessageType -from letta.schemas.message import MessageCreate +from letta.schemas.message import MessageCreateUnion class LettaRequest(BaseModel): - messages: List[MessageCreate] = Field(..., description="The messages to be sent to the agent.") + messages: List[MessageCreateUnion] = Field(..., description="The messages to be sent to the agent.") max_steps: int = Field( default=DEFAULT_MAX_STEPS, description="Maximum number of steps the agent should take to process the request.", @@ -36,6 +36,16 @@ class LettaRequest(BaseModel): description="If set to True, enables reasoning before responses or tool calls from the agent.", ) + @field_validator("messages", mode="before") + @classmethod + def add_default_type_to_messages(cls, v): + """Add default 'message' type for backwards compatibility with older versions of SDK clients that don't send it""" + if isinstance(v, list): + for item in v: + if isinstance(item, dict) and "type" not in item: + item["type"] = "message" + return v + class LettaStreamingRequest(LettaRequest): stream_tokens: bool = Field( diff --git a/letta/schemas/message.py b/letta/schemas/message.py index f2cbf896..badd095a 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -7,10 +7,11 @@ import uuid import warnings from collections import OrderedDict from datetime import datetime, timezone -from typing import Any, Dict, List, Literal, Optional, Union +from enum import Enum +from typing import Annotated, Any, Dict, List, Literal, Optional, Union from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall, Function as OpenAIFunction -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, field_validator, model_validator from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, TOOL_CALL_ID_MAX_LEN from letta.helpers.datetime_helpers import get_utc_time, is_utc_datetime @@ -65,13 +66,19 @@ def add_inner_thoughts_to_tool_call( raise e -class BaseMessage(OrmMetadataBase): - __id_prefix__ = "message" +class MessageCreateType(str, Enum): + message = "message" + approval = "approval" -class MessageCreate(BaseModel): +class MessageCreateBase(BaseModel): + type: MessageCreateType = Field(..., description="The message type to be created.") + + +class MessageCreate(MessageCreateBase): """Request to create a message""" + type: Literal[MessageCreateType.message] = Field(default=MessageCreateType.message, description="The message type to be created.") # In the simplified format, only allow simple roles role: Literal[ MessageRole.user, @@ -97,6 +104,37 @@ class MessageCreate(BaseModel): return data +class ApprovalCreate(MessageCreateBase): + """Input to approve or deny a tool call request""" + + type: Literal[MessageCreateType.approval] = Field(default=MessageCreateType.approval, description="The message type to be created.") + approve: bool = Field(..., description="Whether the tool has been approved") + approval_request_id: str = Field(..., description="The message ID of the approval request") + reason: Optional[str] = Field(None, description="An optional explanation for the provided approval status") + + +MessageCreateUnion = Annotated[ + Union[MessageCreate, ApprovalCreate], + Field(discriminator="type"), +] + + +def create_message_create_union_schema(): + return { + "oneOf": [ + {"$ref": "#/components/schemas/MessageCreate"}, + {"$ref": "#/components/schemas/ApprovalCreate"}, + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "message": "#/components/schemas/MessageCreate", + "approval": "#/components/schemas/ApprovalCreate", + }, + }, + } + + class MessageUpdate(BaseModel): """Request to update a message""" @@ -125,6 +163,10 @@ class MessageUpdate(BaseModel): return data +class BaseMessage(OrmMetadataBase): + __id_prefix__ = "message" + + class Message(BaseMessage): """ Letta's internal representation of a message. Includes methods to convert to/from LLM provider formats. diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index 853b8096..dcdddd20 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -29,6 +29,7 @@ from letta.schemas.letta_message_content import ( create_letta_user_message_content_union_schema, ) from letta.schemas.letta_ping import create_letta_ping_schema +from letta.schemas.message import create_message_create_union_schema from letta.server.constants import REST_DEFAULT_PORT from letta.server.db import db_registry @@ -65,6 +66,7 @@ def generate_openapi_schema(app: FastAPI): letta_docs["paths"] = {k: v for k, v in letta_docs["paths"].items() if not k.startswith("/openai")} letta_docs["info"]["title"] = "Letta API" letta_docs["components"]["schemas"]["LettaMessageUnion"] = create_letta_message_union_schema() + letta_docs["components"]["schemas"]["MessageCreateUnion"] = create_message_create_union_schema() letta_docs["components"]["schemas"]["LettaMessageContentUnion"] = create_letta_message_content_union_schema() letta_docs["components"]["schemas"]["LettaAssistantMessageContentUnion"] = create_letta_assistant_message_content_union_schema() letta_docs["components"]["schemas"]["LettaUserMessageContentUnion"] = create_letta_user_message_content_union_schema() diff --git a/letta/services/agent_serialization_manager.py b/letta/services/agent_serialization_manager.py index f279bab0..a0cca9b8 100644 --- a/letta/services/agent_serialization_manager.py +++ b/letta/services/agent_serialization_manager.py @@ -668,7 +668,7 @@ class AgentSerializationManager: messages = [] for message_schema in agent_schema.messages: # Convert MessageSchema back to Message, setting agent_id to new DB ID - message_data = message_schema.model_dump(exclude={"id"}) + message_data = message_schema.model_dump(exclude={"id", "type"}) message_data["agent_id"] = agent_db_id # Remap agent_id to new database ID message_obj = Message(**message_data) messages.append(message_obj)