diff --git a/letta/schemas/agent_file.py b/letta/schemas/agent_file.py index 5d90250b..ccbdd7f9 100644 --- a/letta/schemas/agent_file.py +++ b/letta/schemas/agent_file.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, Field from letta.helpers.datetime_helpers import get_utc_time from letta.schemas.agent import AgentState, CreateAgent from letta.schemas.block import Block, CreateBlock -from letta.schemas.enums import MessageRole +from letta.schemas.enums import MessageRole, PrimitiveType from letta.schemas.file import FileAgent, FileAgentBase, FileMetadata, FileMetadataBase from letta.schemas.group import Group, GroupCreate from letta.schemas.letta_message import ApprovalReturn @@ -42,7 +42,7 @@ class ImportResult: class MessageSchema(MessageCreate): """Message with human-readable ID for agent file""" - __id_prefix__ = "message" + __id_prefix__ = PrimitiveType.MESSAGE.value id: str = Field(..., description="Human-readable identifier for this message in the file") # Override the role field to accept all message roles, not just user/system/assistant @@ -96,7 +96,7 @@ class MessageSchema(MessageCreate): class FileAgentSchema(FileAgentBase): """File-Agent relationship with human-readable ID for agent file""" - __id_prefix__ = "file_agent" + __id_prefix__ = PrimitiveType.FILE_AGENT.value id: str = Field(..., description="Human-readable identifier for this file-agent relationship in the file") @classmethod @@ -120,7 +120,7 @@ class FileAgentSchema(FileAgentBase): class AgentSchema(CreateAgent): """Agent with human-readable ID for agent file""" - __id_prefix__ = "agent" + __id_prefix__ = PrimitiveType.AGENT.value id: str = Field(..., description="Human-readable identifier for this agent in the file") in_context_message_ids: List[str] = Field( default_factory=list, description="List of message IDs that are currently in the agent's context" @@ -198,7 +198,7 @@ class AgentSchema(CreateAgent): class GroupSchema(GroupCreate): """Group with human-readable ID for agent file""" - __id_prefix__ = "group" + __id_prefix__ = PrimitiveType.GROUP.value id: str = Field(..., description="Human-readable identifier for this group in the file") @classmethod @@ -220,7 +220,7 @@ class GroupSchema(GroupCreate): class BlockSchema(CreateBlock): """Block with human-readable ID for agent file""" - __id_prefix__ = "block" + __id_prefix__ = PrimitiveType.BLOCK.value id: str = Field(..., description="Human-readable identifier for this block in the file") @classmethod @@ -246,7 +246,7 @@ class BlockSchema(CreateBlock): class FileSchema(FileMetadataBase): """File with human-readable ID for agent file""" - __id_prefix__ = "file" + __id_prefix__ = PrimitiveType.FILE.value id: str = Field(..., description="Human-readable identifier for this file in the file") @classmethod @@ -276,7 +276,7 @@ class FileSchema(FileMetadataBase): class SourceSchema(SourceCreate): """Source with human-readable ID for agent file""" - __id_prefix__ = "source" + __id_prefix__ = PrimitiveType.SOURCE.value id: str = Field(..., description="Human-readable identifier for this source in the file") @classmethod @@ -299,7 +299,7 @@ class SourceSchema(SourceCreate): class ToolSchema(Tool): """Tool with human-readable ID for agent file""" - __id_prefix__ = "tool" + __id_prefix__ = PrimitiveType.TOOL.value id: str = Field(..., description="Human-readable identifier for this tool in the file") @classmethod @@ -311,7 +311,7 @@ class ToolSchema(Tool): class MCPServerSchema(BaseModel): """MCP server schema for agent files with remapped ID.""" - __id_prefix__ = "mcp_server" + __id_prefix__ = PrimitiveType.MCP_SERVER.value id: str = Field(..., description="Human-readable MCP server ID") server_type: str diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index 2ac0db12..6cb3ba2f 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -26,6 +26,27 @@ class PrimitiveType(str, Enum): STEP = "step" IDENTITY = "identity" + # Infrastructure types + MCP_SERVER = "mcp_server" + MCP_OAUTH = "mcp-oauth" + FILE_AGENT = "file_agent" + + # Configuration types + SANDBOX_ENV = "sandbox-env" + AGENT_ENV = "agent-env" + + # Core entity types + USER = "user" + ORGANIZATION = "org" + TOOL_RULE = "tool_rule" + + # Batch processing types + BATCH_ITEM = "batch_item" + BATCH_REQUEST = "batch_req" + + # Telemetry types + PROVIDER_TRACE = "provider_trace" + class ProviderType(str, Enum): anthropic = "anthropic" diff --git a/letta/schemas/environment_variables.py b/letta/schemas/environment_variables.py index 0a0e53c7..9485bcb9 100644 --- a/letta/schemas/environment_variables.py +++ b/letta/schemas/environment_variables.py @@ -2,6 +2,7 @@ from typing import Optional from pydantic import Field +from letta.schemas.enums import PrimitiveType from letta.schemas.letta_base import LettaBase, OrmMetadataBase from letta.schemas.secret import Secret from letta.settings import settings @@ -52,7 +53,7 @@ class EnvironmentVariableUpdateBase(LettaBase): # Environment Variable class SandboxEnvironmentVariableBase(EnvironmentVariableBase): - __id_prefix__ = "sandbox-env" + __id_prefix__ = PrimitiveType.SANDBOX_ENV.value sandbox_config_id: str = Field(..., description="The ID of the sandbox config this environment variable belongs to.") @@ -70,7 +71,7 @@ class SandboxEnvironmentVariableUpdate(EnvironmentVariableUpdateBase): # Agent-Specific Environment Variable class AgentEnvironmentVariableBase(EnvironmentVariableBase): - __id_prefix__ = "agent-env" + __id_prefix__ = PrimitiveType.AGENT_ENV.value agent_id: str = Field(..., description="The ID of the agent this environment variable belongs to.") diff --git a/letta/schemas/llm_batch_job.py b/letta/schemas/llm_batch_job.py index e07e148e..a69ee390 100644 --- a/letta/schemas/llm_batch_job.py +++ b/letta/schemas/llm_batch_job.py @@ -5,7 +5,7 @@ from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndi from pydantic import BaseModel, Field from letta.helpers import ToolRulesSolver -from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType +from letta.schemas.enums import AgentStepStatus, JobStatus, PrimitiveType, ProviderType from letta.schemas.letta_base import OrmMetadataBase from letta.schemas.llm_config import LLMConfig @@ -16,7 +16,7 @@ class AgentStepState(BaseModel): class LLMBatchItemBase(OrmMetadataBase, validate_assignment=True): - __id_prefix__ = "batch_item" + __id_prefix__ = PrimitiveType.BATCH_ITEM.value class LLMBatchItem(LLMBatchItemBase, validate_assignment=True): @@ -47,7 +47,7 @@ class LLMBatchJob(OrmMetadataBase, validate_assignment=True): Each job corresponds to one API call that sends multiple messages to the LLM provider, and aggregates responses across all agent submissions. """ - __id_prefix__ = "batch_req" + __id_prefix__ = PrimitiveType.BATCH_REQUEST.value id: Optional[str] = Field(None, description="The id of the batch job. Assigned by the database.") status: JobStatus = Field(..., description="The current status of the batch (e.g., created, in_progress, done).") diff --git a/letta/schemas/mcp.py b/letta/schemas/mcp.py index 92996344..d40f9865 100644 --- a/letta/schemas/mcp.py +++ b/letta/schemas/mcp.py @@ -13,13 +13,14 @@ from letta.functions.mcp_client.types import ( StreamableHTTPServerConfig, ) from letta.orm.mcp_oauth import OAuthSessionStatus +from letta.schemas.enums import PrimitiveType from letta.schemas.letta_base import LettaBase from letta.schemas.secret import Secret from letta.settings import settings class BaseMCPServer(LettaBase): - __id_prefix__ = "mcp_server" + __id_prefix__ = PrimitiveType.MCP_SERVER.value class MCPServer(BaseMCPServer): @@ -178,7 +179,7 @@ UpdateMCPServer = Union[UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamab # OAuth-related schemas class BaseMCPOAuth(LettaBase): - __id_prefix__ = "mcp-oauth" + __id_prefix__ = PrimitiveType.MCP_OAUTH.value class MCPOAuthSession(BaseMCPOAuth): diff --git a/letta/schemas/mcp_server.py b/letta/schemas/mcp_server.py index 5b495c5a..d0364a5f 100644 --- a/letta/schemas/mcp_server.py +++ b/letta/schemas/mcp_server.py @@ -13,12 +13,13 @@ from letta.functions.mcp_client.types import ( StreamableHTTPServerConfig, ) from letta.orm.mcp_oauth import OAuthSessionStatus +from letta.schemas.enums import PrimitiveType from letta.schemas.letta_base import LettaBase from letta.schemas.secret import Secret class BaseMCPServer(LettaBase): - __id_prefix__ = "mcp_server" + __id_prefix__ = PrimitiveType.MCP_SERVER.value # Create Schemas (for POST requests) @@ -101,7 +102,7 @@ UpdateMCPServerUnion = Union[UpdateStdioMCPServer, UpdateSSEMCPServer, UpdateStr # OAuth-related schemas class BaseMCPOAuth(LettaBase): - __id_prefix__ = "mcp-oauth" + __id_prefix__ = PrimitiveType.MCP_OAUTH.value class MCPOAuthSession(BaseMCPOAuth): diff --git a/letta/schemas/organization.py b/letta/schemas/organization.py index 9af86a14..e3d0f7c5 100644 --- a/letta/schemas/organization.py +++ b/letta/schemas/organization.py @@ -4,12 +4,13 @@ from typing import Optional from pydantic import Field from letta.helpers.datetime_helpers import get_utc_time +from letta.schemas.enums import PrimitiveType from letta.schemas.letta_base import LettaBase from letta.utils import create_random_username class OrganizationBase(LettaBase): - __id_prefix__ = "org" + __id_prefix__ = PrimitiveType.ORGANIZATION.value class Organization(OrganizationBase): diff --git a/letta/schemas/passage.py b/letta/schemas/passage.py index fdaac2f2..5396fbf3 100644 --- a/letta/schemas/passage.py +++ b/letta/schemas/passage.py @@ -6,11 +6,12 @@ from pydantic import Field, field_validator from letta.constants import MAX_EMBEDDING_DIM from letta.helpers.datetime_helpers import get_utc_time from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import PrimitiveType from letta.schemas.letta_base import OrmMetadataBase class PassageBase(OrmMetadataBase): - __id_prefix__ = "passage" + __id_prefix__ = PrimitiveType.PASSAGE.value is_deleted: bool = Field(False, description="Whether this passage is deleted or not.") diff --git a/letta/schemas/provider_trace.py b/letta/schemas/provider_trace.py index 382f4b8f..0e75d625 100644 --- a/letta/schemas/provider_trace.py +++ b/letta/schemas/provider_trace.py @@ -6,11 +6,12 @@ from typing import Any, Dict, Optional from pydantic import BaseModel, Field from letta.helpers.datetime_helpers import get_utc_time +from letta.schemas.enums import PrimitiveType from letta.schemas.letta_base import OrmMetadataBase class BaseProviderTrace(OrmMetadataBase): - __id_prefix__ = "provider_trace" + __id_prefix__ = PrimitiveType.PROVIDER_TRACE.value class ProviderTraceCreate(BaseModel): diff --git a/letta/schemas/run_metrics.py b/letta/schemas/run_metrics.py index 458a1cad..e7d21f2a 100644 --- a/letta/schemas/run_metrics.py +++ b/letta/schemas/run_metrics.py @@ -2,11 +2,12 @@ from typing import List, Optional from pydantic import Field +from letta.schemas.enums import PrimitiveType from letta.schemas.letta_base import LettaBase class RunMetricsBase(LettaBase): - __id_prefix__ = "run" + __id_prefix__ = PrimitiveType.RUN.value class RunMetrics(RunMetricsBase): diff --git a/letta/schemas/step_metrics.py b/letta/schemas/step_metrics.py index fb791fc0..321bd178 100644 --- a/letta/schemas/step_metrics.py +++ b/letta/schemas/step_metrics.py @@ -2,11 +2,12 @@ from typing import Optional from pydantic import Field +from letta.schemas.enums import PrimitiveType from letta.schemas.letta_base import LettaBase class StepMetricsBase(LettaBase): - __id_prefix__ = "step" + __id_prefix__ = PrimitiveType.STEP.value class StepMetrics(StepMetricsBase): diff --git a/letta/schemas/tool_rule.py b/letta/schemas/tool_rule.py index e633749c..feba39e8 100644 --- a/letta/schemas/tool_rule.py +++ b/letta/schemas/tool_rule.py @@ -4,14 +4,14 @@ from typing import Annotated, Any, Dict, List, Literal, Optional, Set, Union from pydantic import BaseModel, Field, field_validator, model_validator -from letta.schemas.enums import ToolRuleType +from letta.schemas.enums import PrimitiveType, ToolRuleType from letta.schemas.letta_base import LettaBase logger = logging.getLogger(__name__) class BaseToolRule(LettaBase): - __id_prefix__ = "tool_rule" + __id_prefix__ = PrimitiveType.TOOL_RULE.value tool_name: str = Field(..., description="The name of the tool. Must exist in the database for the user's organization.") type: ToolRuleType = Field(..., description="The type of the message.") prompt_template: Optional[str] = Field( diff --git a/letta/schemas/user.py b/letta/schemas/user.py index 1b92058e..1a657af4 100644 --- a/letta/schemas/user.py +++ b/letta/schemas/user.py @@ -4,11 +4,12 @@ from typing import Optional from pydantic import Field from letta.constants import DEFAULT_ORG_ID +from letta.schemas.enums import PrimitiveType from letta.schemas.letta_base import LettaBase class UserBase(LettaBase): - __id_prefix__ = "user" + __id_prefix__ = PrimitiveType.USER.value class User(UserBase): diff --git a/letta/validators.py b/letta/validators.py index 42726f5b..ed738eba 100644 --- a/letta/validators.py +++ b/letta/validators.py @@ -63,6 +63,27 @@ SandboxConfigId = Annotated[str, PATH_VALIDATORS[PrimitiveType.SANDBOX_CONFIG.va StepId = Annotated[str, PATH_VALIDATORS[PrimitiveType.STEP.value]()] IdentityId = Annotated[str, PATH_VALIDATORS[PrimitiveType.IDENTITY.value]()] +# Infrastructure types +McpServerId = Annotated[str, PATH_VALIDATORS[PrimitiveType.MCP_SERVER.value]()] +McpOAuthId = Annotated[str, PATH_VALIDATORS[PrimitiveType.MCP_OAUTH.value]()] +FileAgentId = Annotated[str, PATH_VALIDATORS[PrimitiveType.FILE_AGENT.value]()] + +# Configuration types +SandboxEnvId = Annotated[str, PATH_VALIDATORS[PrimitiveType.SANDBOX_ENV.value]()] +AgentEnvId = Annotated[str, PATH_VALIDATORS[PrimitiveType.AGENT_ENV.value]()] + +# Core entity types +UserId = Annotated[str, PATH_VALIDATORS[PrimitiveType.USER.value]()] +OrganizationId = Annotated[str, PATH_VALIDATORS[PrimitiveType.ORGANIZATION.value]()] +ToolRuleId = Annotated[str, PATH_VALIDATORS[PrimitiveType.TOOL_RULE.value]()] + +# Batch processing types +BatchItemId = Annotated[str, PATH_VALIDATORS[PrimitiveType.BATCH_ITEM.value]()] +BatchRequestId = Annotated[str, PATH_VALIDATORS[PrimitiveType.BATCH_REQUEST.value]()] + +# Telemetry types +ProviderTraceId = Annotated[str, PATH_VALIDATORS[PrimitiveType.PROVIDER_TRACE.value]()] + def raise_on_invalid_id(param_name: str, expected_prefix: PrimitiveType): """