chore: add new primitive types and replace id_prefixes everywhere (#5749)

add new primitive types and replace id_prefixes everywhere
This commit is contained in:
Kian Jones
2025-10-28 14:54:51 -07:00
committed by Caren Thomas
parent f0de0b5812
commit 3f78c93be5
14 changed files with 78 additions and 27 deletions

View File

@@ -7,7 +7,7 @@ from pydantic import BaseModel, Field
from letta.helpers.datetime_helpers import get_utc_time
from letta.schemas.agent import AgentState, CreateAgent
from letta.schemas.block import Block, CreateBlock
from letta.schemas.enums import MessageRole
from letta.schemas.enums import MessageRole, PrimitiveType
from letta.schemas.file import FileAgent, FileAgentBase, FileMetadata, FileMetadataBase
from letta.schemas.group import Group, GroupCreate
from letta.schemas.letta_message import ApprovalReturn
@@ -42,7 +42,7 @@ class ImportResult:
class MessageSchema(MessageCreate):
"""Message with human-readable ID for agent file"""
__id_prefix__ = "message"
__id_prefix__ = PrimitiveType.MESSAGE.value
id: str = Field(..., description="Human-readable identifier for this message in the file")
# Override the role field to accept all message roles, not just user/system/assistant
@@ -96,7 +96,7 @@ class MessageSchema(MessageCreate):
class FileAgentSchema(FileAgentBase):
"""File-Agent relationship with human-readable ID for agent file"""
__id_prefix__ = "file_agent"
__id_prefix__ = PrimitiveType.FILE_AGENT.value
id: str = Field(..., description="Human-readable identifier for this file-agent relationship in the file")
@classmethod
@@ -120,7 +120,7 @@ class FileAgentSchema(FileAgentBase):
class AgentSchema(CreateAgent):
"""Agent with human-readable ID for agent file"""
__id_prefix__ = "agent"
__id_prefix__ = PrimitiveType.AGENT.value
id: str = Field(..., description="Human-readable identifier for this agent in the file")
in_context_message_ids: List[str] = Field(
default_factory=list, description="List of message IDs that are currently in the agent's context"
@@ -198,7 +198,7 @@ class AgentSchema(CreateAgent):
class GroupSchema(GroupCreate):
"""Group with human-readable ID for agent file"""
__id_prefix__ = "group"
__id_prefix__ = PrimitiveType.GROUP.value
id: str = Field(..., description="Human-readable identifier for this group in the file")
@classmethod
@@ -220,7 +220,7 @@ class GroupSchema(GroupCreate):
class BlockSchema(CreateBlock):
"""Block with human-readable ID for agent file"""
__id_prefix__ = "block"
__id_prefix__ = PrimitiveType.BLOCK.value
id: str = Field(..., description="Human-readable identifier for this block in the file")
@classmethod
@@ -246,7 +246,7 @@ class BlockSchema(CreateBlock):
class FileSchema(FileMetadataBase):
"""File with human-readable ID for agent file"""
__id_prefix__ = "file"
__id_prefix__ = PrimitiveType.FILE.value
id: str = Field(..., description="Human-readable identifier for this file in the file")
@classmethod
@@ -276,7 +276,7 @@ class FileSchema(FileMetadataBase):
class SourceSchema(SourceCreate):
"""Source with human-readable ID for agent file"""
__id_prefix__ = "source"
__id_prefix__ = PrimitiveType.SOURCE.value
id: str = Field(..., description="Human-readable identifier for this source in the file")
@classmethod
@@ -299,7 +299,7 @@ class SourceSchema(SourceCreate):
class ToolSchema(Tool):
"""Tool with human-readable ID for agent file"""
__id_prefix__ = "tool"
__id_prefix__ = PrimitiveType.TOOL.value
id: str = Field(..., description="Human-readable identifier for this tool in the file")
@classmethod
@@ -311,7 +311,7 @@ class ToolSchema(Tool):
class MCPServerSchema(BaseModel):
"""MCP server schema for agent files with remapped ID."""
__id_prefix__ = "mcp_server"
__id_prefix__ = PrimitiveType.MCP_SERVER.value
id: str = Field(..., description="Human-readable MCP server ID")
server_type: str

View File

@@ -26,6 +26,27 @@ class PrimitiveType(str, Enum):
STEP = "step"
IDENTITY = "identity"
# Infrastructure types
MCP_SERVER = "mcp_server"
MCP_OAUTH = "mcp-oauth"
FILE_AGENT = "file_agent"
# Configuration types
SANDBOX_ENV = "sandbox-env"
AGENT_ENV = "agent-env"
# Core entity types
USER = "user"
ORGANIZATION = "org"
TOOL_RULE = "tool_rule"
# Batch processing types
BATCH_ITEM = "batch_item"
BATCH_REQUEST = "batch_req"
# Telemetry types
PROVIDER_TRACE = "provider_trace"
class ProviderType(str, Enum):
anthropic = "anthropic"

View File

@@ -2,6 +2,7 @@ from typing import Optional
from pydantic import Field
from letta.schemas.enums import PrimitiveType
from letta.schemas.letta_base import LettaBase, OrmMetadataBase
from letta.schemas.secret import Secret
from letta.settings import settings
@@ -52,7 +53,7 @@ class EnvironmentVariableUpdateBase(LettaBase):
# Environment Variable
class SandboxEnvironmentVariableBase(EnvironmentVariableBase):
__id_prefix__ = "sandbox-env"
__id_prefix__ = PrimitiveType.SANDBOX_ENV.value
sandbox_config_id: str = Field(..., description="The ID of the sandbox config this environment variable belongs to.")
@@ -70,7 +71,7 @@ class SandboxEnvironmentVariableUpdate(EnvironmentVariableUpdateBase):
# Agent-Specific Environment Variable
class AgentEnvironmentVariableBase(EnvironmentVariableBase):
__id_prefix__ = "agent-env"
__id_prefix__ = PrimitiveType.AGENT_ENV.value
agent_id: str = Field(..., description="The ID of the agent this environment variable belongs to.")

View File

@@ -5,7 +5,7 @@ from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndi
from pydantic import BaseModel, Field
from letta.helpers import ToolRulesSolver
from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType
from letta.schemas.enums import AgentStepStatus, JobStatus, PrimitiveType, ProviderType
from letta.schemas.letta_base import OrmMetadataBase
from letta.schemas.llm_config import LLMConfig
@@ -16,7 +16,7 @@ class AgentStepState(BaseModel):
class LLMBatchItemBase(OrmMetadataBase, validate_assignment=True):
__id_prefix__ = "batch_item"
__id_prefix__ = PrimitiveType.BATCH_ITEM.value
class LLMBatchItem(LLMBatchItemBase, validate_assignment=True):
@@ -47,7 +47,7 @@ class LLMBatchJob(OrmMetadataBase, validate_assignment=True):
Each job corresponds to one API call that sends multiple messages to the LLM provider, and aggregates responses across all agent submissions.
"""
__id_prefix__ = "batch_req"
__id_prefix__ = PrimitiveType.BATCH_REQUEST.value
id: Optional[str] = Field(None, description="The id of the batch job. Assigned by the database.")
status: JobStatus = Field(..., description="The current status of the batch (e.g., created, in_progress, done).")

View File

@@ -13,13 +13,14 @@ from letta.functions.mcp_client.types import (
StreamableHTTPServerConfig,
)
from letta.orm.mcp_oauth import OAuthSessionStatus
from letta.schemas.enums import PrimitiveType
from letta.schemas.letta_base import LettaBase
from letta.schemas.secret import Secret
from letta.settings import settings
class BaseMCPServer(LettaBase):
__id_prefix__ = "mcp_server"
__id_prefix__ = PrimitiveType.MCP_SERVER.value
class MCPServer(BaseMCPServer):
@@ -178,7 +179,7 @@ UpdateMCPServer = Union[UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamab
# OAuth-related schemas
class BaseMCPOAuth(LettaBase):
__id_prefix__ = "mcp-oauth"
__id_prefix__ = PrimitiveType.MCP_OAUTH.value
class MCPOAuthSession(BaseMCPOAuth):

View File

@@ -13,12 +13,13 @@ from letta.functions.mcp_client.types import (
StreamableHTTPServerConfig,
)
from letta.orm.mcp_oauth import OAuthSessionStatus
from letta.schemas.enums import PrimitiveType
from letta.schemas.letta_base import LettaBase
from letta.schemas.secret import Secret
class BaseMCPServer(LettaBase):
__id_prefix__ = "mcp_server"
__id_prefix__ = PrimitiveType.MCP_SERVER.value
# Create Schemas (for POST requests)
@@ -101,7 +102,7 @@ UpdateMCPServerUnion = Union[UpdateStdioMCPServer, UpdateSSEMCPServer, UpdateStr
# OAuth-related schemas
class BaseMCPOAuth(LettaBase):
__id_prefix__ = "mcp-oauth"
__id_prefix__ = PrimitiveType.MCP_OAUTH.value
class MCPOAuthSession(BaseMCPOAuth):

View File

@@ -4,12 +4,13 @@ from typing import Optional
from pydantic import Field
from letta.helpers.datetime_helpers import get_utc_time
from letta.schemas.enums import PrimitiveType
from letta.schemas.letta_base import LettaBase
from letta.utils import create_random_username
class OrganizationBase(LettaBase):
__id_prefix__ = "org"
__id_prefix__ = PrimitiveType.ORGANIZATION.value
class Organization(OrganizationBase):

View File

@@ -6,11 +6,12 @@ from pydantic import Field, field_validator
from letta.constants import MAX_EMBEDDING_DIM
from letta.helpers.datetime_helpers import get_utc_time
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import PrimitiveType
from letta.schemas.letta_base import OrmMetadataBase
class PassageBase(OrmMetadataBase):
__id_prefix__ = "passage"
__id_prefix__ = PrimitiveType.PASSAGE.value
is_deleted: bool = Field(False, description="Whether this passage is deleted or not.")

View File

@@ -6,11 +6,12 @@ from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from letta.helpers.datetime_helpers import get_utc_time
from letta.schemas.enums import PrimitiveType
from letta.schemas.letta_base import OrmMetadataBase
class BaseProviderTrace(OrmMetadataBase):
__id_prefix__ = "provider_trace"
__id_prefix__ = PrimitiveType.PROVIDER_TRACE.value
class ProviderTraceCreate(BaseModel):

View File

@@ -2,11 +2,12 @@ from typing import List, Optional
from pydantic import Field
from letta.schemas.enums import PrimitiveType
from letta.schemas.letta_base import LettaBase
class RunMetricsBase(LettaBase):
__id_prefix__ = "run"
__id_prefix__ = PrimitiveType.RUN.value
class RunMetrics(RunMetricsBase):

View File

@@ -2,11 +2,12 @@ from typing import Optional
from pydantic import Field
from letta.schemas.enums import PrimitiveType
from letta.schemas.letta_base import LettaBase
class StepMetricsBase(LettaBase):
__id_prefix__ = "step"
__id_prefix__ = PrimitiveType.STEP.value
class StepMetrics(StepMetricsBase):

View File

@@ -4,14 +4,14 @@ from typing import Annotated, Any, Dict, List, Literal, Optional, Set, Union
from pydantic import BaseModel, Field, field_validator, model_validator
from letta.schemas.enums import ToolRuleType
from letta.schemas.enums import PrimitiveType, ToolRuleType
from letta.schemas.letta_base import LettaBase
logger = logging.getLogger(__name__)
class BaseToolRule(LettaBase):
__id_prefix__ = "tool_rule"
__id_prefix__ = PrimitiveType.TOOL_RULE.value
tool_name: str = Field(..., description="The name of the tool. Must exist in the database for the user's organization.")
type: ToolRuleType = Field(..., description="The type of the message.")
prompt_template: Optional[str] = Field(

View File

@@ -4,11 +4,12 @@ from typing import Optional
from pydantic import Field
from letta.constants import DEFAULT_ORG_ID
from letta.schemas.enums import PrimitiveType
from letta.schemas.letta_base import LettaBase
class UserBase(LettaBase):
__id_prefix__ = "user"
__id_prefix__ = PrimitiveType.USER.value
class User(UserBase):

View File

@@ -63,6 +63,27 @@ SandboxConfigId = Annotated[str, PATH_VALIDATORS[PrimitiveType.SANDBOX_CONFIG.va
StepId = Annotated[str, PATH_VALIDATORS[PrimitiveType.STEP.value]()]
IdentityId = Annotated[str, PATH_VALIDATORS[PrimitiveType.IDENTITY.value]()]
# Infrastructure types
McpServerId = Annotated[str, PATH_VALIDATORS[PrimitiveType.MCP_SERVER.value]()]
McpOAuthId = Annotated[str, PATH_VALIDATORS[PrimitiveType.MCP_OAUTH.value]()]
FileAgentId = Annotated[str, PATH_VALIDATORS[PrimitiveType.FILE_AGENT.value]()]
# Configuration types
SandboxEnvId = Annotated[str, PATH_VALIDATORS[PrimitiveType.SANDBOX_ENV.value]()]
AgentEnvId = Annotated[str, PATH_VALIDATORS[PrimitiveType.AGENT_ENV.value]()]
# Core entity types
UserId = Annotated[str, PATH_VALIDATORS[PrimitiveType.USER.value]()]
OrganizationId = Annotated[str, PATH_VALIDATORS[PrimitiveType.ORGANIZATION.value]()]
ToolRuleId = Annotated[str, PATH_VALIDATORS[PrimitiveType.TOOL_RULE.value]()]
# Batch processing types
BatchItemId = Annotated[str, PATH_VALIDATORS[PrimitiveType.BATCH_ITEM.value]()]
BatchRequestId = Annotated[str, PATH_VALIDATORS[PrimitiveType.BATCH_REQUEST.value]()]
# Telemetry types
ProviderTraceId = Annotated[str, PATH_VALIDATORS[PrimitiveType.PROVIDER_TRACE.value]()]
def raise_on_invalid_id(param_name: str, expected_prefix: PrimitiveType):
"""