chore: message schema api improvements (#1267)
This commit is contained in:
@@ -39,7 +39,8 @@ from letta.orm.enums import ToolType
|
||||
from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent
|
||||
from letta.schemas.block import BlockUpdate
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageContentType, MessageRole
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.memory import ContextWindowOverview, Memory
|
||||
from letta.schemas.message import Message, ToolReturn
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
@@ -165,7 +166,7 @@ class Agent(BaseAgent):
|
||||
in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_state.id, actor=self.user)
|
||||
for i in range(len(in_context_messages) - 1, -1, -1):
|
||||
msg = in_context_messages[i]
|
||||
if msg.role == MessageRole.tool and msg.content and len(msg.content) == 1 and msg.content[0].type == MessageContentType.text:
|
||||
if msg.role == MessageRole.tool and msg.content and len(msg.content) == 1 and isinstance(msg.content[0], TextContent):
|
||||
text_content = msg.content[0].text
|
||||
try:
|
||||
response_json = json.loads(text_content)
|
||||
@@ -1210,7 +1211,7 @@ class Agent(BaseAgent):
|
||||
and in_context_messages[1].role == MessageRole.user
|
||||
and in_context_messages[1].content
|
||||
and len(in_context_messages[1].content) == 1
|
||||
and in_context_messages[1].content[0].type == MessageContentType.text
|
||||
and isinstance(in_context_messages[1].content[0], TextContent)
|
||||
# TODO remove hardcoding
|
||||
and "The following is a summary of the previous " in in_context_messages[1].content[0].text
|
||||
):
|
||||
|
||||
@@ -5,7 +5,8 @@ import openai
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import TextContent, UserMessage
|
||||
from letta.schemas.letta_message import UserMessage
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
|
||||
from letta.schemas.user import User
|
||||
|
||||
@@ -4,7 +4,7 @@ from letta.agent import Agent, AgentState
|
||||
from letta.interface import AgentInterface
|
||||
from letta.orm import User
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.letta_message import TextContent
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
|
||||
@@ -7,7 +7,8 @@ import requests
|
||||
from letta.constants import MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.llm_api.llm_api_tools import create
|
||||
from letta.schemas.message import Message, TextContent
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message
|
||||
|
||||
|
||||
def message_chatgpt(self, message: str):
|
||||
|
||||
@@ -5,8 +5,9 @@ from letta.llm_api.llm_api_tools import create
|
||||
from letta.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.message import Message, TextContent
|
||||
from letta.schemas.message import Message
|
||||
from letta.settings import summarizer_settings
|
||||
from letta.utils import count_tokens, printd
|
||||
|
||||
|
||||
@@ -7,8 +7,8 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from letta.orm.custom_columns import ToolCallColumn, ToolReturnColumn
|
||||
from letta.orm.mixins import AgentMixin, OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.letta_message_content import TextContent as PydanticTextContent
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import TextContent as PydanticTextContent
|
||||
from letta.schemas.message import ToolReturn
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import List, Optional
|
||||
from letta.agent import Agent, AgentState
|
||||
from letta.interface import AgentInterface
|
||||
from letta.orm import User
|
||||
from letta.schemas.letta_message import TextContent
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
|
||||
@@ -9,10 +9,6 @@ class MessageRole(str, Enum):
|
||||
system = "system"
|
||||
|
||||
|
||||
class MessageContentType(str, Enum):
|
||||
text = "text"
|
||||
|
||||
|
||||
class OptionState(str, Enum):
|
||||
"""Useful for kwargs that are bool + default option"""
|
||||
|
||||
|
||||
@@ -4,87 +4,83 @@ from typing import Annotated, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_serializer, field_validator
|
||||
|
||||
from letta.schemas.enums import MessageContentType
|
||||
from letta.schemas.letta_message_content import LettaMessageContentUnion, get_letta_message_content_union_str_json_schema
|
||||
|
||||
# Letta API style responses (intended to be easier to use vs getting true Message types)
|
||||
# ---------------------------
|
||||
# Letta API Messaging Schemas
|
||||
# ---------------------------
|
||||
|
||||
|
||||
class LettaMessage(BaseModel):
|
||||
"""
|
||||
Base class for simplified Letta message response type. This is intended to be used for developers who want the internal monologue, tool calls, and tool returns in a simplified format that does not include additional information other than the content and timestamp.
|
||||
Base class for simplified Letta message response type. This is intended to be used for developers
|
||||
who want the internal monologue, tool calls, and tool returns in a simplified format that does not
|
||||
include additional information other than the content and timestamp.
|
||||
|
||||
Attributes:
|
||||
Args:
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
|
||||
"""
|
||||
|
||||
# NOTE: use Pydantic's discriminated unions feature: https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
|
||||
# see `message_type` attribute
|
||||
|
||||
id: str
|
||||
date: datetime
|
||||
|
||||
@field_serializer("date")
|
||||
def serialize_datetime(self, dt: datetime, _info):
|
||||
"""
|
||||
Remove microseconds since it seems like we're inconsistent with getting them
|
||||
TODO: figure out why we don't always get microseconds (get_utc_time() does)
|
||||
"""
|
||||
if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None:
|
||||
dt = dt.replace(tzinfo=timezone.utc)
|
||||
# Remove microseconds since it seems like we're inconsistent with getting them
|
||||
# TODO figure out why we don't always get microseconds (get_utc_time() does)
|
||||
return dt.isoformat(timespec="seconds")
|
||||
|
||||
|
||||
class MessageContent(BaseModel):
|
||||
type: MessageContentType = Field(..., description="The type of the message.")
|
||||
|
||||
|
||||
class TextContent(MessageContent):
|
||||
type: Literal[MessageContentType.text] = Field(MessageContentType.text, description="The type of the message.")
|
||||
text: str = Field(..., description="The text content of the message.")
|
||||
|
||||
|
||||
MessageContentUnion = Annotated[
|
||||
Union[TextContent],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
class SystemMessage(LettaMessage):
|
||||
"""
|
||||
A message generated by the system. Never streamed back on a response, only used for cursor pagination.
|
||||
|
||||
Attributes:
|
||||
content (Union[str, List[MessageContentUnion]]): The message content sent by the user (can be a string or an array of content parts)
|
||||
Args:
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
content (Union[str, List[LettaMessageContentUnion]]): The message content sent by the system (can be a string or an array of content parts)
|
||||
"""
|
||||
|
||||
message_type: Literal["system_message"] = "system_message"
|
||||
content: Union[str, List[MessageContentUnion]]
|
||||
content: Union[str, List[LettaMessageContentUnion]] = Field(
|
||||
...,
|
||||
description="The message content sent by the system (can be a string or an array of content parts)",
|
||||
json_schema_extra=get_letta_message_content_union_str_json_schema(),
|
||||
)
|
||||
|
||||
|
||||
class UserMessage(LettaMessage):
|
||||
"""
|
||||
A message sent by the user. Never streamed back on a response, only used for cursor pagination.
|
||||
|
||||
Attributes:
|
||||
content (Union[str, List[MessageContentUnion]]): The message content sent by the user (can be a string or an array of content parts)
|
||||
Args:
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
content (Union[str, List[LettaMessageContentUnion]]): The message content sent by the user (can be a string or an array of content parts)
|
||||
"""
|
||||
|
||||
message_type: Literal["user_message"] = "user_message"
|
||||
content: Union[str, List[MessageContentUnion]]
|
||||
content: Union[str, List[LettaMessageContentUnion]] = Field(
|
||||
...,
|
||||
description="The message content sent by the user (can be a string or an array of content parts)",
|
||||
json_schema_extra=get_letta_message_content_union_str_json_schema(),
|
||||
)
|
||||
|
||||
|
||||
class ReasoningMessage(LettaMessage):
|
||||
"""
|
||||
Representation of an agent's internal reasoning.
|
||||
|
||||
Attributes:
|
||||
reasoning (str): The internal reasoning of the agent
|
||||
Args:
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
reasoning (str): The internal reasoning of the agent
|
||||
"""
|
||||
|
||||
message_type: Literal["reasoning_message"] = "reasoning_message"
|
||||
@@ -92,21 +88,21 @@ class ReasoningMessage(LettaMessage):
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
|
||||
name: str
|
||||
arguments: str
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class ToolCallDelta(BaseModel):
|
||||
|
||||
name: Optional[str]
|
||||
arguments: Optional[str]
|
||||
tool_call_id: Optional[str]
|
||||
|
||||
# NOTE: this is a workaround to exclude None values from the JSON dump,
|
||||
# since the OpenAI style of returning chunks doesn't include keys with null values
|
||||
def model_dump(self, *args, **kwargs):
|
||||
"""
|
||||
This is a workaround to exclude None values from the JSON dump since the
|
||||
OpenAI style of returning chunks doesn't include keys with null values.
|
||||
"""
|
||||
kwargs["exclude_none"] = True
|
||||
return super().model_dump(*args, **kwargs)
|
||||
|
||||
@@ -118,17 +114,19 @@ class ToolCallMessage(LettaMessage):
|
||||
"""
|
||||
A message representing a request to call a tool (generated by the LLM to trigger tool execution).
|
||||
|
||||
Attributes:
|
||||
tool_call (Union[ToolCall, ToolCallDelta]): The tool call
|
||||
Args:
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
tool_call (Union[ToolCall, ToolCallDelta]): The tool call
|
||||
"""
|
||||
|
||||
message_type: Literal["tool_call_message"] = "tool_call_message"
|
||||
tool_call: Union[ToolCall, ToolCallDelta]
|
||||
|
||||
# NOTE: this is required for the ToolCallDelta exclude_none to work correctly
|
||||
def model_dump(self, *args, **kwargs):
|
||||
"""
|
||||
Handling for the ToolCallDelta exclude_none to work correctly
|
||||
"""
|
||||
kwargs["exclude_none"] = True
|
||||
data = super().model_dump(*args, **kwargs)
|
||||
if isinstance(data["tool_call"], dict):
|
||||
@@ -141,12 +139,14 @@ class ToolCallMessage(LettaMessage):
|
||||
ToolCall: lambda v: v.model_dump(exclude_none=True),
|
||||
}
|
||||
|
||||
# NOTE: this is required to cast dicts into ToolCallMessage objects
|
||||
# Without this extra validator, Pydantic will throw an error if 'name' or 'arguments' are None
|
||||
# (instead of properly casting to ToolCallDelta instead of ToolCall)
|
||||
@field_validator("tool_call", mode="before")
|
||||
@classmethod
|
||||
def validate_tool_call(cls, v):
|
||||
"""
|
||||
Casts dicts into ToolCallMessage objects. Without this extra validator, Pydantic will throw
|
||||
an error if 'name' or 'arguments' are None instead of properly casting to ToolCallDelta
|
||||
instead of ToolCall.
|
||||
"""
|
||||
if isinstance(v, dict):
|
||||
if "name" in v and "arguments" in v and "tool_call_id" in v:
|
||||
return ToolCall(name=v["name"], arguments=v["arguments"], tool_call_id=v["tool_call_id"])
|
||||
@@ -161,11 +161,11 @@ class ToolReturnMessage(LettaMessage):
|
||||
"""
|
||||
A message representing the return value of a tool call (generated by Letta executing the requested tool).
|
||||
|
||||
Attributes:
|
||||
tool_return (str): The return value of the tool
|
||||
status (Literal["success", "error"]): The status of the tool call
|
||||
Args:
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
tool_return (str): The return value of the tool
|
||||
status (Literal["success", "error"]): The status of the tool call
|
||||
tool_call_id (str): A unique identifier for the tool call that generated this message
|
||||
stdout (Optional[List(str)]): Captured stdout (e.g. prints, logs) from the tool invocation
|
||||
stderr (Optional[List(str)]): Captured stderr from the tool invocation
|
||||
@@ -179,89 +179,31 @@ class ToolReturnMessage(LettaMessage):
|
||||
stderr: Optional[List[str]] = None
|
||||
|
||||
|
||||
# Legacy Letta API had an additional type "assistant_message" and the "function_call" was a formatted string
|
||||
|
||||
|
||||
class AssistantMessage(LettaMessage):
|
||||
"""
|
||||
A message sent by the LLM in response to user input. Used in the LLM context.
|
||||
|
||||
Args:
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
content (Union[str, List[LettaMessageContentUnion]]): The message content sent by the agent (can be a string or an array of content parts)
|
||||
"""
|
||||
|
||||
message_type: Literal["assistant_message"] = "assistant_message"
|
||||
content: Union[str, List[MessageContentUnion]]
|
||||
|
||||
|
||||
class LegacyFunctionCallMessage(LettaMessage):
|
||||
function_call: str
|
||||
|
||||
|
||||
class LegacyFunctionReturn(LettaMessage):
|
||||
"""
|
||||
A message representing the return value of a function call (generated by Letta executing the requested function).
|
||||
|
||||
Attributes:
|
||||
function_return (str): The return value of the function
|
||||
status (Literal["success", "error"]): The status of the function call
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
function_call_id (str): A unique identifier for the function call that generated this message
|
||||
stdout (Optional[List(str)]): Captured stdout (e.g. prints, logs) from the function invocation
|
||||
stderr (Optional[List(str)]): Captured stderr from the function invocation
|
||||
"""
|
||||
|
||||
message_type: Literal["function_return"] = "function_return"
|
||||
function_return: str
|
||||
status: Literal["success", "error"]
|
||||
function_call_id: str
|
||||
stdout: Optional[List[str]] = None
|
||||
stderr: Optional[List[str]] = None
|
||||
|
||||
|
||||
class LegacyInternalMonologue(LettaMessage):
|
||||
"""
|
||||
Representation of an agent's internal monologue.
|
||||
|
||||
Attributes:
|
||||
internal_monologue (str): The internal monologue of the agent
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
"""
|
||||
|
||||
message_type: Literal["internal_monologue"] = "internal_monologue"
|
||||
internal_monologue: str
|
||||
|
||||
|
||||
LegacyLettaMessage = Union[LegacyInternalMonologue, AssistantMessage, LegacyFunctionCallMessage, LegacyFunctionReturn]
|
||||
content: Union[str, List[LettaMessageContentUnion]] = Field(
|
||||
...,
|
||||
description="The message content sent by the agent (can be a string or an array of content parts)",
|
||||
json_schema_extra=get_letta_message_content_union_str_json_schema(),
|
||||
)
|
||||
|
||||
|
||||
# NOTE: use Pydantic's discriminated unions feature: https://docs.pydantic.dev/latest/concepts/unions/#discriminated-unions
|
||||
LettaMessageUnion = Annotated[
|
||||
Union[SystemMessage, UserMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage, AssistantMessage],
|
||||
Field(discriminator="message_type"),
|
||||
]
|
||||
|
||||
|
||||
class UpdateSystemMessage(BaseModel):
|
||||
content: Union[str, List[MessageContentUnion]]
|
||||
message_type: Literal["system_message"] = "system_message"
|
||||
|
||||
|
||||
class UpdateUserMessage(BaseModel):
|
||||
content: Union[str, List[MessageContentUnion]]
|
||||
message_type: Literal["user_message"] = "user_message"
|
||||
|
||||
|
||||
class UpdateReasoningMessage(BaseModel):
|
||||
reasoning: Union[str, List[MessageContentUnion]]
|
||||
message_type: Literal["reasoning_message"] = "reasoning_message"
|
||||
|
||||
|
||||
class UpdateAssistantMessage(BaseModel):
|
||||
content: Union[str, List[MessageContentUnion]]
|
||||
message_type: Literal["assistant_message"] = "assistant_message"
|
||||
|
||||
|
||||
LettaMessageUpdateUnion = Annotated[
|
||||
Union[UpdateSystemMessage, UpdateUserMessage, UpdateReasoningMessage, UpdateAssistantMessage],
|
||||
Field(discriminator="message_type"),
|
||||
]
|
||||
|
||||
|
||||
def create_letta_message_union_schema():
|
||||
return {
|
||||
"oneOf": [
|
||||
@@ -284,3 +226,94 @@ def create_letta_message_union_schema():
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# --------------------------
|
||||
# Message Update API Schemas
|
||||
# --------------------------
|
||||
|
||||
|
||||
class UpdateSystemMessage(BaseModel):
|
||||
message_type: Literal["system_message"] = "system_message"
|
||||
content: Union[str, List[LettaMessageContentUnion]] = Field(
|
||||
...,
|
||||
description="The message content sent by the system (can be a string or an array of content parts)",
|
||||
json_schema_extra=get_letta_message_content_union_str_json_schema(),
|
||||
)
|
||||
|
||||
|
||||
class UpdateUserMessage(BaseModel):
|
||||
message_type: Literal["user_message"] = "user_message"
|
||||
content: Union[str, List[LettaMessageContentUnion]] = Field(
|
||||
...,
|
||||
description="The message content sent by the user (can be a string or an array of content parts)",
|
||||
json_schema_extra=get_letta_message_content_union_str_json_schema(),
|
||||
)
|
||||
|
||||
|
||||
class UpdateReasoningMessage(BaseModel):
|
||||
reasoning: str
|
||||
message_type: Literal["reasoning_message"] = "reasoning_message"
|
||||
|
||||
|
||||
class UpdateAssistantMessage(BaseModel):
|
||||
message_type: Literal["assistant_message"] = "assistant_message"
|
||||
content: Union[str, List[LettaMessageContentUnion]] = Field(
|
||||
...,
|
||||
description="The message content sent by the assistant (can be a string or an array of content parts)",
|
||||
json_schema_extra=get_letta_message_content_union_str_json_schema(),
|
||||
)
|
||||
|
||||
|
||||
LettaMessageUpdateUnion = Annotated[
|
||||
Union[UpdateSystemMessage, UpdateUserMessage, UpdateReasoningMessage, UpdateAssistantMessage],
|
||||
Field(discriminator="message_type"),
|
||||
]
|
||||
|
||||
|
||||
# --------------------------
|
||||
# Deprecated Message Schemas
|
||||
# --------------------------
|
||||
|
||||
|
||||
class LegacyFunctionCallMessage(LettaMessage):
|
||||
function_call: str
|
||||
|
||||
|
||||
class LegacyFunctionReturn(LettaMessage):
|
||||
"""
|
||||
A message representing the return value of a function call (generated by Letta executing the requested function).
|
||||
|
||||
Args:
|
||||
function_return (str): The return value of the function
|
||||
status (Literal["success", "error"]): The status of the function call
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
function_call_id (str): A unique identifier for the function call that generated this message
|
||||
stdout (Optional[List(str)]): Captured stdout (e.g. prints, logs) from the function invocation
|
||||
stderr (Optional[List(str)]): Captured stderr from the function invocation
|
||||
"""
|
||||
|
||||
message_type: Literal["function_return"] = "function_return"
|
||||
function_return: str
|
||||
status: Literal["success", "error"]
|
||||
function_call_id: str
|
||||
stdout: Optional[List[str]] = None
|
||||
stderr: Optional[List[str]] = None
|
||||
|
||||
|
||||
class LegacyInternalMonologue(LettaMessage):
|
||||
"""
|
||||
Representation of an agent's internal monologue.
|
||||
|
||||
Args:
|
||||
internal_monologue (str): The internal monologue of the agent
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
"""
|
||||
|
||||
message_type: Literal["internal_monologue"] = "internal_monologue"
|
||||
internal_monologue: str
|
||||
|
||||
|
||||
LegacyLettaMessage = Union[LegacyInternalMonologue, AssistantMessage, LegacyFunctionCallMessage, LegacyFunctionReturn]
|
||||
|
||||
56
letta/schemas/letta_message_content.py
Normal file
56
letta/schemas/letta_message_content.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from enum import Enum
|
||||
from typing import Annotated, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MessageContentType(str, Enum):
|
||||
text = "text"
|
||||
|
||||
|
||||
class MessageContent(BaseModel):
|
||||
type: MessageContentType = Field(..., description="The type of the message.")
|
||||
|
||||
|
||||
# -------------------------------
|
||||
# Multi-Modal Content Types
|
||||
# -------------------------------
|
||||
|
||||
|
||||
class TextContent(MessageContent):
|
||||
type: Literal[MessageContentType.text] = Field(MessageContentType.text, description="The type of the message.")
|
||||
text: str = Field(..., description="The text content of the message.")
|
||||
|
||||
|
||||
LettaMessageContentUnion = Annotated[
|
||||
Union[TextContent],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
|
||||
def create_letta_message_content_union_schema():
|
||||
return {
|
||||
"oneOf": [
|
||||
{"$ref": "#/components/schemas/TextContent"},
|
||||
],
|
||||
"discriminator": {
|
||||
"propertyName": "type",
|
||||
"mapping": {
|
||||
"text": "#/components/schemas/TextContent",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_letta_message_content_union_str_json_schema():
|
||||
return {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/components/schemas/LettaMessageContentUnion",
|
||||
},
|
||||
},
|
||||
{"type": "string"},
|
||||
],
|
||||
}
|
||||
@@ -15,20 +15,19 @@ from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, TO
|
||||
from letta.helpers.datetime_helpers import get_utc_time, is_utc_datetime
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
from letta.schemas.enums import MessageContentType, MessageRole
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
from letta.schemas.letta_message import (
|
||||
AssistantMessage,
|
||||
LettaMessage,
|
||||
MessageContentUnion,
|
||||
ReasoningMessage,
|
||||
SystemMessage,
|
||||
TextContent,
|
||||
ToolCall,
|
||||
ToolCallMessage,
|
||||
ToolReturnMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from letta.schemas.letta_message_content import LettaMessageContentUnion, TextContent
|
||||
from letta.system import unpack_message
|
||||
|
||||
|
||||
@@ -66,7 +65,7 @@ class MessageCreate(BaseModel):
|
||||
MessageRole.user,
|
||||
MessageRole.system,
|
||||
] = Field(..., description="The role of the participant.")
|
||||
content: Union[str, List[MessageContentUnion]] = Field(..., description="The content of the message.")
|
||||
content: Union[str, List[LettaMessageContentUnion]] = Field(..., description="The content of the message.")
|
||||
name: Optional[str] = Field(None, description="The name of the participant.")
|
||||
|
||||
|
||||
@@ -74,7 +73,7 @@ class MessageUpdate(BaseModel):
|
||||
"""Request to update a message"""
|
||||
|
||||
role: Optional[MessageRole] = Field(None, description="The role of the participant.")
|
||||
content: Optional[Union[str, List[MessageContentUnion]]] = Field(None, description="The content of the message.")
|
||||
content: Optional[Union[str, List[LettaMessageContentUnion]]] = Field(None, description="The content of the message.")
|
||||
# NOTE: probably doesn't make sense to allow remapping user_id or agent_id (vs creating a new message)
|
||||
# user_id: Optional[str] = Field(None, description="The unique identifier of the user.")
|
||||
# agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
|
||||
@@ -119,7 +118,7 @@ class Message(BaseMessage):
|
||||
|
||||
id: str = BaseMessage.generate_id_field()
|
||||
role: MessageRole = Field(..., description="The role of the participant.")
|
||||
content: Optional[List[MessageContentUnion]] = Field(None, description="The content of the message.")
|
||||
content: Optional[List[LettaMessageContentUnion]] = Field(None, description="The content of the message.")
|
||||
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization.")
|
||||
agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
|
||||
model: Optional[str] = Field(None, description="The model used to make the function call.")
|
||||
@@ -215,7 +214,7 @@ class Message(BaseMessage):
|
||||
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
) -> List[LettaMessage]:
|
||||
"""Convert message object (in DB format) to the style used by the original Letta API"""
|
||||
if self.content and len(self.content) == 1 and self.content[0].type == MessageContentType.text:
|
||||
if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent):
|
||||
text_content = self.content[0].text
|
||||
else:
|
||||
text_content = None
|
||||
@@ -486,7 +485,7 @@ class Message(BaseMessage):
|
||||
"""Go from Message class to ChatCompletion message object"""
|
||||
|
||||
# TODO change to pydantic casting, eg `return SystemMessageModel(self)`
|
||||
if self.content and len(self.content) == 1 and self.content[0].type == MessageContentType.text:
|
||||
if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent):
|
||||
text_content = self.content[0].text
|
||||
else:
|
||||
text_content = None
|
||||
@@ -561,7 +560,7 @@ class Message(BaseMessage):
|
||||
Args:
|
||||
inner_thoughts_xml_tag (str): The XML tag to wrap around inner thoughts
|
||||
"""
|
||||
if self.content and len(self.content) == 1 and self.content[0].type == MessageContentType.text:
|
||||
if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent):
|
||||
text_content = self.content[0].text
|
||||
else:
|
||||
text_content = None
|
||||
@@ -656,7 +655,7 @@ class Message(BaseMessage):
|
||||
# type Content: https://ai.google.dev/api/rest/v1/Content / https://ai.google.dev/api/rest/v1beta/Content
|
||||
# parts[]: Part
|
||||
# role: str ('user' or 'model')
|
||||
if self.content and len(self.content) == 1 and self.content[0].type == MessageContentType.text:
|
||||
if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent):
|
||||
text_content = self.content[0].text
|
||||
else:
|
||||
text_content = None
|
||||
@@ -782,7 +781,7 @@ class Message(BaseMessage):
|
||||
|
||||
# TODO: update this prompt style once guidance from Cohere on
|
||||
# embedded function calls in multi-turn conversation become more clear
|
||||
if self.content and len(self.content) == 1 and self.content[0].type == MessageContentType.text:
|
||||
if self.content and len(self.content) == 1 and isinstance(self.content[0], TextContent):
|
||||
text_content = self.content[0].text
|
||||
else:
|
||||
text_content = None
|
||||
|
||||
@@ -17,6 +17,7 @@ from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaU
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
|
||||
from letta.schemas.letta_message import create_letta_message_union_schema
|
||||
from letta.schemas.letta_message_content import create_letta_message_content_union_schema
|
||||
from letta.server.constants import REST_DEFAULT_PORT
|
||||
|
||||
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
|
||||
@@ -68,6 +69,7 @@ def generate_openapi_schema(app: FastAPI):
|
||||
letta_docs["paths"] = {k: v for k, v in letta_docs["paths"].items() if not k.startswith("/openai")}
|
||||
letta_docs["info"]["title"] = "Letta API"
|
||||
letta_docs["components"]["schemas"]["LettaMessageUnion"] = create_letta_message_union_schema()
|
||||
letta_docs["components"]["schemas"]["LettaMessageContentUnion"] = create_letta_message_content_union_schema()
|
||||
for name, docs in [
|
||||
(
|
||||
"letta",
|
||||
|
||||
@@ -530,7 +530,7 @@ def list_messages(
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/messages/{message_id}", response_model=LettaMessageUpdateUnion, operation_id="modify_message")
|
||||
@router.patch("/{agent_id}/messages/{message_id}", response_model=LettaMessageUnion, operation_id="modify_message")
|
||||
def modify_message(
|
||||
agent_id: str,
|
||||
message_id: str,
|
||||
|
||||
@@ -18,7 +18,7 @@ from letta.errors import ContextWindowExceededError, RateLimitExceededError
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import TextContent
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
|
||||
@@ -49,10 +49,11 @@ from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
|
||||
from letta.schemas.group import Group, ManagerType
|
||||
from letta.schemas.job import Job, JobUpdate
|
||||
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, ToolReturnMessage
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ArchivalMemorySummary, ContextWindowOverview, Memory, RecallMemorySummary
|
||||
from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate, TextContent
|
||||
from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.passage import Passage, PassageUpdate
|
||||
from letta.schemas.providers import (
|
||||
|
||||
@@ -11,8 +11,9 @@ from letta.orm.errors import NoResultFound
|
||||
from letta.prompts import gpt_system
|
||||
from letta.schemas.agent import AgentState, AgentType
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.message import Message, MessageCreate, TextContent
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.tool_rule import ToolRule
|
||||
from letta.schemas.user import User
|
||||
|
||||
@@ -6,7 +6,7 @@ from letta.functions.function_sets.multi_agent import send_message_to_all_agents
|
||||
from letta.interface import AgentInterface
|
||||
from letta.orm import User
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.letta_message import TextContent
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
|
||||
@@ -13,8 +13,9 @@ from letta.errors import ContextWindowExceededError
|
||||
from letta.llm_api.helpers import calculate_summarizer_cutoff
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message, TextContent
|
||||
from letta.schemas.message import Message
|
||||
from letta.settings import summarizer_settings
|
||||
from letta.streaming_interface import StreamingRefreshCLIInterface
|
||||
from tests.helpers.endpoints_helper import EMBEDDING_CONFIG_PATH
|
||||
|
||||
@@ -6,7 +6,7 @@ import pytest
|
||||
|
||||
from letta.agents.base_agent import BaseAgent
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import TextContent
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message
|
||||
from letta.services.summarizer.enums import SummarizationMode
|
||||
from letta.services.summarizer.summarizer import Summarizer
|
||||
|
||||
Reference in New Issue
Block a user