chore: add new primitive types and replace id_prefixes everywhere (#5749)
add new primitive types and replace id_prefixes everywhere
This commit is contained in:
@@ -7,7 +7,7 @@ from pydantic import BaseModel, Field
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.schemas.agent import AgentState, CreateAgent
|
||||
from letta.schemas.block import Block, CreateBlock
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.enums import MessageRole, PrimitiveType
|
||||
from letta.schemas.file import FileAgent, FileAgentBase, FileMetadata, FileMetadataBase
|
||||
from letta.schemas.group import Group, GroupCreate
|
||||
from letta.schemas.letta_message import ApprovalReturn
|
||||
@@ -42,7 +42,7 @@ class ImportResult:
|
||||
class MessageSchema(MessageCreate):
|
||||
"""Message with human-readable ID for agent file"""
|
||||
|
||||
__id_prefix__ = "message"
|
||||
__id_prefix__ = PrimitiveType.MESSAGE.value
|
||||
id: str = Field(..., description="Human-readable identifier for this message in the file")
|
||||
|
||||
# Override the role field to accept all message roles, not just user/system/assistant
|
||||
@@ -96,7 +96,7 @@ class MessageSchema(MessageCreate):
|
||||
class FileAgentSchema(FileAgentBase):
|
||||
"""File-Agent relationship with human-readable ID for agent file"""
|
||||
|
||||
__id_prefix__ = "file_agent"
|
||||
__id_prefix__ = PrimitiveType.FILE_AGENT.value
|
||||
id: str = Field(..., description="Human-readable identifier for this file-agent relationship in the file")
|
||||
|
||||
@classmethod
|
||||
@@ -120,7 +120,7 @@ class FileAgentSchema(FileAgentBase):
|
||||
class AgentSchema(CreateAgent):
|
||||
"""Agent with human-readable ID for agent file"""
|
||||
|
||||
__id_prefix__ = "agent"
|
||||
__id_prefix__ = PrimitiveType.AGENT.value
|
||||
id: str = Field(..., description="Human-readable identifier for this agent in the file")
|
||||
in_context_message_ids: List[str] = Field(
|
||||
default_factory=list, description="List of message IDs that are currently in the agent's context"
|
||||
@@ -198,7 +198,7 @@ class AgentSchema(CreateAgent):
|
||||
class GroupSchema(GroupCreate):
|
||||
"""Group with human-readable ID for agent file"""
|
||||
|
||||
__id_prefix__ = "group"
|
||||
__id_prefix__ = PrimitiveType.GROUP.value
|
||||
id: str = Field(..., description="Human-readable identifier for this group in the file")
|
||||
|
||||
@classmethod
|
||||
@@ -220,7 +220,7 @@ class GroupSchema(GroupCreate):
|
||||
class BlockSchema(CreateBlock):
|
||||
"""Block with human-readable ID for agent file"""
|
||||
|
||||
__id_prefix__ = "block"
|
||||
__id_prefix__ = PrimitiveType.BLOCK.value
|
||||
id: str = Field(..., description="Human-readable identifier for this block in the file")
|
||||
|
||||
@classmethod
|
||||
@@ -246,7 +246,7 @@ class BlockSchema(CreateBlock):
|
||||
class FileSchema(FileMetadataBase):
|
||||
"""File with human-readable ID for agent file"""
|
||||
|
||||
__id_prefix__ = "file"
|
||||
__id_prefix__ = PrimitiveType.FILE.value
|
||||
id: str = Field(..., description="Human-readable identifier for this file in the file")
|
||||
|
||||
@classmethod
|
||||
@@ -276,7 +276,7 @@ class FileSchema(FileMetadataBase):
|
||||
class SourceSchema(SourceCreate):
|
||||
"""Source with human-readable ID for agent file"""
|
||||
|
||||
__id_prefix__ = "source"
|
||||
__id_prefix__ = PrimitiveType.SOURCE.value
|
||||
id: str = Field(..., description="Human-readable identifier for this source in the file")
|
||||
|
||||
@classmethod
|
||||
@@ -299,7 +299,7 @@ class SourceSchema(SourceCreate):
|
||||
class ToolSchema(Tool):
|
||||
"""Tool with human-readable ID for agent file"""
|
||||
|
||||
__id_prefix__ = "tool"
|
||||
__id_prefix__ = PrimitiveType.TOOL.value
|
||||
id: str = Field(..., description="Human-readable identifier for this tool in the file")
|
||||
|
||||
@classmethod
|
||||
@@ -311,7 +311,7 @@ class ToolSchema(Tool):
|
||||
class MCPServerSchema(BaseModel):
|
||||
"""MCP server schema for agent files with remapped ID."""
|
||||
|
||||
__id_prefix__ = "mcp_server"
|
||||
__id_prefix__ = PrimitiveType.MCP_SERVER.value
|
||||
|
||||
id: str = Field(..., description="Human-readable MCP server ID")
|
||||
server_type: str
|
||||
|
||||
@@ -26,6 +26,27 @@ class PrimitiveType(str, Enum):
|
||||
STEP = "step"
|
||||
IDENTITY = "identity"
|
||||
|
||||
# Infrastructure types
|
||||
MCP_SERVER = "mcp_server"
|
||||
MCP_OAUTH = "mcp-oauth"
|
||||
FILE_AGENT = "file_agent"
|
||||
|
||||
# Configuration types
|
||||
SANDBOX_ENV = "sandbox-env"
|
||||
AGENT_ENV = "agent-env"
|
||||
|
||||
# Core entity types
|
||||
USER = "user"
|
||||
ORGANIZATION = "org"
|
||||
TOOL_RULE = "tool_rule"
|
||||
|
||||
# Batch processing types
|
||||
BATCH_ITEM = "batch_item"
|
||||
BATCH_REQUEST = "batch_req"
|
||||
|
||||
# Telemetry types
|
||||
PROVIDER_TRACE = "provider_trace"
|
||||
|
||||
|
||||
class ProviderType(str, Enum):
|
||||
anthropic = "anthropic"
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.schemas.letta_base import LettaBase, OrmMetadataBase
|
||||
from letta.schemas.secret import Secret
|
||||
from letta.settings import settings
|
||||
@@ -52,7 +53,7 @@ class EnvironmentVariableUpdateBase(LettaBase):
|
||||
|
||||
# Environment Variable
|
||||
class SandboxEnvironmentVariableBase(EnvironmentVariableBase):
|
||||
__id_prefix__ = "sandbox-env"
|
||||
__id_prefix__ = PrimitiveType.SANDBOX_ENV.value
|
||||
sandbox_config_id: str = Field(..., description="The ID of the sandbox config this environment variable belongs to.")
|
||||
|
||||
|
||||
@@ -70,7 +71,7 @@ class SandboxEnvironmentVariableUpdate(EnvironmentVariableUpdateBase):
|
||||
|
||||
# Agent-Specific Environment Variable
|
||||
class AgentEnvironmentVariableBase(EnvironmentVariableBase):
|
||||
__id_prefix__ = "agent-env"
|
||||
__id_prefix__ = PrimitiveType.AGENT_ENV.value
|
||||
agent_id: str = Field(..., description="The ID of the agent this environment variable belongs to.")
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndi
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType
|
||||
from letta.schemas.enums import AgentStepStatus, JobStatus, PrimitiveType, ProviderType
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
@@ -16,7 +16,7 @@ class AgentStepState(BaseModel):
|
||||
|
||||
|
||||
class LLMBatchItemBase(OrmMetadataBase, validate_assignment=True):
|
||||
__id_prefix__ = "batch_item"
|
||||
__id_prefix__ = PrimitiveType.BATCH_ITEM.value
|
||||
|
||||
|
||||
class LLMBatchItem(LLMBatchItemBase, validate_assignment=True):
|
||||
@@ -47,7 +47,7 @@ class LLMBatchJob(OrmMetadataBase, validate_assignment=True):
|
||||
Each job corresponds to one API call that sends multiple messages to the LLM provider, and aggregates responses across all agent submissions.
|
||||
"""
|
||||
|
||||
__id_prefix__ = "batch_req"
|
||||
__id_prefix__ = PrimitiveType.BATCH_REQUEST.value
|
||||
|
||||
id: Optional[str] = Field(None, description="The id of the batch job. Assigned by the database.")
|
||||
status: JobStatus = Field(..., description="The current status of the batch (e.g., created, in_progress, done).")
|
||||
|
||||
@@ -13,13 +13,14 @@ from letta.functions.mcp_client.types import (
|
||||
StreamableHTTPServerConfig,
|
||||
)
|
||||
from letta.orm.mcp_oauth import OAuthSessionStatus
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.schemas.secret import Secret
|
||||
from letta.settings import settings
|
||||
|
||||
|
||||
class BaseMCPServer(LettaBase):
|
||||
__id_prefix__ = "mcp_server"
|
||||
__id_prefix__ = PrimitiveType.MCP_SERVER.value
|
||||
|
||||
|
||||
class MCPServer(BaseMCPServer):
|
||||
@@ -178,7 +179,7 @@ UpdateMCPServer = Union[UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamab
|
||||
|
||||
# OAuth-related schemas
|
||||
class BaseMCPOAuth(LettaBase):
|
||||
__id_prefix__ = "mcp-oauth"
|
||||
__id_prefix__ = PrimitiveType.MCP_OAUTH.value
|
||||
|
||||
|
||||
class MCPOAuthSession(BaseMCPOAuth):
|
||||
|
||||
@@ -13,12 +13,13 @@ from letta.functions.mcp_client.types import (
|
||||
StreamableHTTPServerConfig,
|
||||
)
|
||||
from letta.orm.mcp_oauth import OAuthSessionStatus
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.schemas.secret import Secret
|
||||
|
||||
|
||||
class BaseMCPServer(LettaBase):
|
||||
__id_prefix__ = "mcp_server"
|
||||
__id_prefix__ = PrimitiveType.MCP_SERVER.value
|
||||
|
||||
|
||||
# Create Schemas (for POST requests)
|
||||
@@ -101,7 +102,7 @@ UpdateMCPServerUnion = Union[UpdateStdioMCPServer, UpdateSSEMCPServer, UpdateStr
|
||||
|
||||
# OAuth-related schemas
|
||||
class BaseMCPOAuth(LettaBase):
|
||||
__id_prefix__ = "mcp-oauth"
|
||||
__id_prefix__ = PrimitiveType.MCP_OAUTH.value
|
||||
|
||||
|
||||
class MCPOAuthSession(BaseMCPOAuth):
|
||||
|
||||
@@ -4,12 +4,13 @@ from typing import Optional
|
||||
from pydantic import Field
|
||||
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.utils import create_random_username
|
||||
|
||||
|
||||
class OrganizationBase(LettaBase):
|
||||
__id_prefix__ = "org"
|
||||
__id_prefix__ = PrimitiveType.ORGANIZATION.value
|
||||
|
||||
|
||||
class Organization(OrganizationBase):
|
||||
|
||||
@@ -6,11 +6,12 @@ from pydantic import Field, field_validator
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
|
||||
|
||||
class PassageBase(OrmMetadataBase):
|
||||
__id_prefix__ = "passage"
|
||||
__id_prefix__ = PrimitiveType.PASSAGE.value
|
||||
|
||||
is_deleted: bool = Field(False, description="Whether this passage is deleted or not.")
|
||||
|
||||
|
||||
@@ -6,11 +6,12 @@ from typing import Any, Dict, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
|
||||
|
||||
class BaseProviderTrace(OrmMetadataBase):
|
||||
__id_prefix__ = "provider_trace"
|
||||
__id_prefix__ = PrimitiveType.PROVIDER_TRACE.value
|
||||
|
||||
|
||||
class ProviderTraceCreate(BaseModel):
|
||||
|
||||
@@ -2,11 +2,12 @@ from typing import List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
class RunMetricsBase(LettaBase):
|
||||
__id_prefix__ = "run"
|
||||
__id_prefix__ = PrimitiveType.RUN.value
|
||||
|
||||
|
||||
class RunMetrics(RunMetricsBase):
|
||||
|
||||
@@ -2,11 +2,12 @@ from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
class StepMetricsBase(LettaBase):
|
||||
__id_prefix__ = "step"
|
||||
__id_prefix__ = PrimitiveType.STEP.value
|
||||
|
||||
|
||||
class StepMetrics(StepMetricsBase):
|
||||
|
||||
@@ -4,14 +4,14 @@ from typing import Annotated, Any, Dict, List, Literal, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from letta.schemas.enums import ToolRuleType
|
||||
from letta.schemas.enums import PrimitiveType, ToolRuleType
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseToolRule(LettaBase):
|
||||
__id_prefix__ = "tool_rule"
|
||||
__id_prefix__ = PrimitiveType.TOOL_RULE.value
|
||||
tool_name: str = Field(..., description="The name of the tool. Must exist in the database for the user's organization.")
|
||||
type: ToolRuleType = Field(..., description="The type of the message.")
|
||||
prompt_template: Optional[str] = Field(
|
||||
|
||||
@@ -4,11 +4,12 @@ from typing import Optional
|
||||
from pydantic import Field
|
||||
|
||||
from letta.constants import DEFAULT_ORG_ID
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
class UserBase(LettaBase):
|
||||
__id_prefix__ = "user"
|
||||
__id_prefix__ = PrimitiveType.USER.value
|
||||
|
||||
|
||||
class User(UserBase):
|
||||
|
||||
@@ -63,6 +63,27 @@ SandboxConfigId = Annotated[str, PATH_VALIDATORS[PrimitiveType.SANDBOX_CONFIG.va
|
||||
StepId = Annotated[str, PATH_VALIDATORS[PrimitiveType.STEP.value]()]
|
||||
IdentityId = Annotated[str, PATH_VALIDATORS[PrimitiveType.IDENTITY.value]()]
|
||||
|
||||
# Infrastructure types
|
||||
McpServerId = Annotated[str, PATH_VALIDATORS[PrimitiveType.MCP_SERVER.value]()]
|
||||
McpOAuthId = Annotated[str, PATH_VALIDATORS[PrimitiveType.MCP_OAUTH.value]()]
|
||||
FileAgentId = Annotated[str, PATH_VALIDATORS[PrimitiveType.FILE_AGENT.value]()]
|
||||
|
||||
# Configuration types
|
||||
SandboxEnvId = Annotated[str, PATH_VALIDATORS[PrimitiveType.SANDBOX_ENV.value]()]
|
||||
AgentEnvId = Annotated[str, PATH_VALIDATORS[PrimitiveType.AGENT_ENV.value]()]
|
||||
|
||||
# Core entity types
|
||||
UserId = Annotated[str, PATH_VALIDATORS[PrimitiveType.USER.value]()]
|
||||
OrganizationId = Annotated[str, PATH_VALIDATORS[PrimitiveType.ORGANIZATION.value]()]
|
||||
ToolRuleId = Annotated[str, PATH_VALIDATORS[PrimitiveType.TOOL_RULE.value]()]
|
||||
|
||||
# Batch processing types
|
||||
BatchItemId = Annotated[str, PATH_VALIDATORS[PrimitiveType.BATCH_ITEM.value]()]
|
||||
BatchRequestId = Annotated[str, PATH_VALIDATORS[PrimitiveType.BATCH_REQUEST.value]()]
|
||||
|
||||
# Telemetry types
|
||||
ProviderTraceId = Annotated[str, PATH_VALIDATORS[PrimitiveType.PROVIDER_TRACE.value]()]
|
||||
|
||||
|
||||
def raise_on_invalid_id(param_name: str, expected_prefix: PrimitiveType):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user