diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index 3662ecf6..26050c33 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -8,6 +8,7 @@ from letta.constants import CORE_MEMORY_LINE_NUMBER_WARNING, DEFAULT_EMBEDDING_C from letta.errors import AgentExportProcessingError from letta.schemas.block import Block, CreateBlock from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import PrimitiveType from letta.schemas.environment_variables import AgentEnvironmentVariable from letta.schemas.file import FileStatus from letta.schemas.group import Group @@ -57,7 +58,7 @@ class AgentState(OrmMetadataBase, validate_assignment=True): embedding_config (EmbeddingConfig): The embedding configuration used by the agent. """ - __id_prefix__ = "agent" + __id_prefix__ = PrimitiveType.AGENT.value # NOTE: this is what is returned to the client and also what is used to initialize `Agent` id: str = Field(..., description="The id of the agent. Assigned by the database.") diff --git a/letta/schemas/archive.py b/letta/schemas/archive.py index 55727e92..cd8e2ac0 100644 --- a/letta/schemas/archive.py +++ b/letta/schemas/archive.py @@ -3,12 +3,12 @@ from typing import Dict, Optional from pydantic import Field -from letta.schemas.enums import VectorDBProvider +from letta.schemas.enums import PrimitiveType, VectorDBProvider from letta.schemas.letta_base import OrmMetadataBase class ArchiveBase(OrmMetadataBase): - __id_prefix__ = "archive" + __id_prefix__ = PrimitiveType.ARCHIVE.value name: str = Field(..., description="The name of the archive") description: Optional[str] = Field(None, description="A description of the archive") diff --git a/letta/schemas/block.py b/letta/schemas/block.py index b18ee8b9..bac5ed02 100644 --- a/letta/schemas/block.py +++ b/letta/schemas/block.py @@ -4,6 +4,7 @@ from typing import Any, Optional from pydantic import ConfigDict, Field, model_validator from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT, DEFAULT_HUMAN_BLOCK_DESCRIPTION, DEFAULT_PERSONA_BLOCK_DESCRIPTION +from letta.schemas.enums import PrimitiveType from letta.schemas.letta_base import LettaBase # block of the LLM context @@ -12,7 +13,7 @@ from letta.schemas.letta_base import LettaBase class BaseBlock(LettaBase, validate_assignment=True): """Base block of the LLM context""" - __id_prefix__ = "block" + __id_prefix__ = PrimitiveType.BLOCK.value # data value value: str = Field(..., description="Value of the block.") diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index 4120b81d..b7f72457 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -1,6 +1,31 @@ from enum import Enum, StrEnum +class PrimitiveType(str, Enum): + """ + Enum for all primitive resource types in Letta. + + The enum values ARE the actual ID prefixes used in the system. + This serves as the single source of truth for all ID prefixes. + """ + + AGENT = "agent" + MESSAGE = "message" + RUN = "run" + JOB = "job" + GROUP = "group" + BLOCK = "block" + FILE = "file" + FOLDER = "source" # Note: folder IDs use "source" prefix for historical reasons + SOURCE = "source" + TOOL = "tool" + ARCHIVE = "archive" + PROVIDER = "provider" + SANDBOX_CONFIG = "sandbox" # Note: sandbox_config IDs use "sandbox" prefix + STEP = "step" + IDENTITY = "identity" + + class ProviderType(str, Enum): anthropic = "anthropic" azure = "azure" diff --git a/letta/schemas/file.py b/letta/schemas/file.py index 93e36d67..6e305fb6 100644 --- a/letta/schemas/file.py +++ b/letta/schemas/file.py @@ -4,7 +4,7 @@ from typing import List, Optional from pydantic import Field -from letta.schemas.enums import FileProcessingStatus +from letta.schemas.enums import FileProcessingStatus, PrimitiveType from letta.schemas.letta_base import LettaBase @@ -20,7 +20,7 @@ class FileStatus(str, Enum): class FileMetadataBase(LettaBase): """Base class for FileMetadata schemas""" - __id_prefix__ = "file" + __id_prefix__ = PrimitiveType.FILE.value # Core file metadata fields source_id: str = Field(..., description="The unique identifier of the source associated with the document.") @@ -61,7 +61,7 @@ class FileMetadata(FileMetadataBase): class FileAgentBase(LettaBase): """Base class for the FileMetadata-⇄-Agent association schemas""" - __id_prefix__ = "file_agent" + __id_prefix__ = PrimitiveType.FILE.value # Core file-agent association fields agent_id: str = Field(..., description="Unique identifier of the agent.") diff --git a/letta/schemas/folder.py b/letta/schemas/folder.py index a60aa2cd..149dd3a1 100644 --- a/letta/schemas/folder.py +++ b/letta/schemas/folder.py @@ -4,6 +4,7 @@ from typing import Optional from pydantic import Field from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import PrimitiveType from letta.schemas.letta_base import LettaBase @@ -12,7 +13,7 @@ class BaseFolder(LettaBase): Shared attributes across all folder schemas. """ - __id_prefix__ = "source" # TODO: change to "folder" + __id_prefix__ = PrimitiveType.FOLDER.value # TODO: change to "folder" # Core folder fields name: str = Field(..., description="The name of the folder.") diff --git a/letta/schemas/group.py b/letta/schemas/group.py index 2bc82c89..a5c06c0c 100644 --- a/letta/schemas/group.py +++ b/letta/schemas/group.py @@ -3,6 +3,7 @@ from typing import Annotated, List, Literal, Optional, Union from pydantic import BaseModel, Field +from letta.schemas.enums import PrimitiveType from letta.schemas.letta_base import LettaBase @@ -20,7 +21,7 @@ class ManagerConfig(BaseModel): class GroupBase(LettaBase): - __id_prefix__ = "group" + __id_prefix__ = PrimitiveType.GROUP.value class Group(GroupBase): diff --git a/letta/schemas/identity.py b/letta/schemas/identity.py index 147683d5..83bd879f 100644 --- a/letta/schemas/identity.py +++ b/letta/schemas/identity.py @@ -3,6 +3,7 @@ from typing import List, Optional, Union from pydantic import Field +from letta.schemas.enums import PrimitiveType from letta.schemas.letta_base import LettaBase @@ -28,7 +29,7 @@ class IdentityPropertyType(str, Enum): class IdentityBase(LettaBase): - __id_prefix__ = "identity" + __id_prefix__ = PrimitiveType.IDENTITY.value class IdentityProperty(LettaBase): diff --git a/letta/schemas/job.py b/letta/schemas/job.py index 007381fd..26393967 100644 --- a/letta/schemas/job.py +++ b/letta/schemas/job.py @@ -3,6 +3,8 @@ from typing import TYPE_CHECKING, List, Optional from pydantic import BaseModel, ConfigDict, Field +from letta.schemas.enums import PrimitiveType + if TYPE_CHECKING: from letta.schemas.letta_request import LettaRequest @@ -15,7 +17,7 @@ from letta.schemas.letta_stop_reason import StopReasonType class JobBase(OrmMetadataBase): - __id_prefix__ = "job" + __id_prefix__ = PrimitiveType.JOB.value status: JobStatus = Field(default=JobStatus.created, description="The status of the job.") created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the job was created.") diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 937593d0..d6218c92 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -19,7 +19,7 @@ from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, RE from letta.helpers.datetime_helpers import get_utc_time, is_utc_datetime from letta.helpers.json_helpers import json_dumps from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_VERTEX -from letta.schemas.enums import MessageRole +from letta.schemas.enums import MessageRole, PrimitiveType from letta.schemas.letta_base import OrmMetadataBase from letta.schemas.letta_message import ( ApprovalRequestMessage, @@ -170,7 +170,7 @@ class MessageUpdate(BaseModel): class BaseMessage(OrmMetadataBase): - __id_prefix__ = "message" + __id_prefix__ = PrimitiveType.MESSAGE.value class Message(BaseMessage): diff --git a/letta/schemas/providers/base.py b/letta/schemas/providers/base.py index a88fd5a7..1de956b2 100644 --- a/letta/schemas/providers/base.py +++ b/letta/schemas/providers/base.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Field, model_validator from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES -from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.enums import PrimitiveType, ProviderCategory, ProviderType from letta.schemas.letta_base import LettaBase from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES @@ -13,7 +13,7 @@ from letta.settings import model_settings class ProviderBase(LettaBase): - __id_prefix__ = "provider" + __id_prefix__ = PrimitiveType.PROVIDER.value class Provider(ProviderBase): diff --git a/letta/schemas/run.py b/letta/schemas/run.py index 8b3bf6fb..e01f3781 100644 --- a/letta/schemas/run.py +++ b/letta/schemas/run.py @@ -4,14 +4,14 @@ from typing import Optional from pydantic import ConfigDict, Field from letta.helpers.datetime_helpers import get_utc_time -from letta.schemas.enums import RunStatus +from letta.schemas.enums import PrimitiveType, RunStatus from letta.schemas.job import LettaRequestConfig from letta.schemas.letta_base import LettaBase from letta.schemas.letta_stop_reason import StopReasonType class RunBase(LettaBase): - __id_prefix__ = "run" + __id_prefix__ = PrimitiveType.RUN.value class Run(RunBase): diff --git a/letta/schemas/sandbox_config.py b/letta/schemas/sandbox_config.py index b4869563..045b892d 100644 --- a/letta/schemas/sandbox_config.py +++ b/letta/schemas/sandbox_config.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, model_validator from letta.constants import LETTA_TOOL_EXECUTION_DIR from letta.schemas.agent import AgentState -from letta.schemas.enums import SandboxType +from letta.schemas.enums import PrimitiveType, SandboxType from letta.schemas.letta_base import LettaBase, OrmMetadataBase from letta.schemas.pip_requirement import PipRequirement from letta.services.tool_sandbox.modal_constants import DEFAULT_MODAL_TIMEOUT @@ -92,7 +92,7 @@ class ModalSandboxConfig(BaseModel): class SandboxConfigBase(OrmMetadataBase): - __id_prefix__ = "sandbox" + __id_prefix__ = PrimitiveType.SANDBOX_CONFIG.value class SandboxConfig(SandboxConfigBase): diff --git a/letta/schemas/source.py b/letta/schemas/source.py index cd816ef3..69a16719 100644 --- a/letta/schemas/source.py +++ b/letta/schemas/source.py @@ -5,7 +5,7 @@ from pydantic import Field from letta.helpers.tpuf_client import should_use_tpuf from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import VectorDBProvider +from letta.schemas.enums import PrimitiveType, VectorDBProvider from letta.schemas.letta_base import LettaBase @@ -14,7 +14,7 @@ class BaseSource(LettaBase): Shared attributes across all source schemas. """ - __id_prefix__ = "source" + __id_prefix__ = PrimitiveType.SOURCE.value # Core source fields name: str = Field(..., description="The name of the source.") diff --git a/letta/schemas/step.py b/letta/schemas/step.py index ec5e7bdc..38eb8cde 100644 --- a/letta/schemas/step.py +++ b/letta/schemas/step.py @@ -3,14 +3,14 @@ from typing import Dict, List, Literal, Optional from pydantic import Field -from letta.schemas.enums import StepStatus +from letta.schemas.enums import PrimitiveType, StepStatus from letta.schemas.letta_base import LettaBase from letta.schemas.letta_stop_reason import StopReasonType from letta.schemas.message import Message class StepBase(LettaBase): - __id_prefix__ = "step" + __id_prefix__ = PrimitiveType.STEP.value class Step(StepBase): diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index 74622192..fcaa8d33 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -11,6 +11,7 @@ from letta.constants import ( LETTA_VOICE_TOOL_MODULE_NAME, MCP_TOOL_TAG_NAME_PREFIX, ) +from letta.schemas.enums import PrimitiveType # MCP Tool metadata constants for schema health status MCP_TOOL_METADATA_SCHEMA_STATUS = f"{MCP_TOOL_TAG_NAME_PREFIX}:SCHEMA_STATUS" @@ -28,7 +29,7 @@ logger = get_logger(__name__) class BaseTool(LettaBase): - __id_prefix__ = "tool" + __id_prefix__ = PrimitiveType.TOOL.value class Tool(BaseTool): diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index d57d736c..dd94a74c 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -58,7 +58,7 @@ from letta.schemas.agent import ( ) from letta.schemas.block import DEFAULT_BLOCKS, Block as PydanticBlock, BlockUpdate from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import AgentType, ProviderType, TagMatchMode, ToolType, VectorDBProvider +from letta.schemas.enums import AgentType, PrimitiveType, ProviderType, TagMatchMode, ToolType, VectorDBProvider from letta.schemas.file import FileMetadata as PydanticFileMetadata from letta.schemas.group import Group as PydanticGroup, ManagerType from letta.schemas.llm_config import LLMConfig @@ -110,7 +110,7 @@ from letta.services.source_manager import SourceManager from letta.services.tool_manager import ToolManager from letta.settings import DatabaseChoice, model_settings, settings from letta.utils import calculate_file_defaults_based_on_context_window, enforce_types, united_diff -from letta.validators import is_valid_id +from letta.validators import raise_on_invalid_id logger = get_logger(__name__) @@ -672,6 +672,7 @@ class AgentManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) async def update_agent_async( self, agent_id: str, @@ -976,6 +977,7 @@ class AgentManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) async def get_agent_by_id_async( self, agent_id: str, @@ -984,10 +986,6 @@ class AgentManager: ) -> PydanticAgentState: """Fetch an agent by its ID.""" - # Check if agent_id matches uuid4 format - if not is_valid_id("agent", agent_id): - raise LettaAgentNotFoundError(f"agent_id {agent_id} is not in the valid format 'agent-'") - async with db_registry.async_session() as session: try: query = select(AgentModel) @@ -1039,6 +1037,7 @@ class AgentManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) async def get_agent_archive_ids_async(self, agent_id: str, actor: PydanticUser) -> List[str]: """Get all archive IDs associated with an agent.""" from letta.orm import ArchivesAgents @@ -1052,6 +1051,7 @@ class AgentManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) async def validate_agent_exists_async(self, agent_id: str, actor: PydanticUser) -> None: """ Validate that an agent exists and user has access to it. @@ -1069,6 +1069,7 @@ class AgentManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) async def delete_agent_async(self, agent_id: str, actor: PydanticUser) -> None: """ Deletes an agent and its associated relationships. @@ -1131,6 +1132,7 @@ class AgentManager: # TODO: This can also be made more efficient, instead of getting, setting, we can do it all in one db session for one query. @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) async def get_in_context_messages(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]: agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor) return await self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor) @@ -1143,6 +1145,7 @@ class AgentManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) async def get_system_message_async(self, agent_id: str, actor: PydanticUser) -> PydanticMessage: agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor) return await self.message_manager.get_message_by_id_async(message_id=agent.message_ids[0], actor=actor) @@ -1324,6 +1327,7 @@ class AgentManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) async def set_in_context_messages_async(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState: return await self.update_agent_async(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor) @@ -1434,6 +1438,7 @@ class AgentManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) async def update_memory_if_changed_async(self, agent_id: str, new_memory: Memory, actor: PydanticUser) -> PydanticAgentState: """ Update internal memory object and system prompt if there have been modifications. @@ -1546,6 +1551,8 @@ class AgentManager: # ====================================================================================================================== @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) + @raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE) async def attach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState: """ Attaches a source to an agent. @@ -1615,6 +1622,7 @@ class AgentManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) async def append_system_message_async(self, agent_id: str, content: str, actor: PydanticUser): """ Async version of append_system_message. @@ -1702,6 +1710,8 @@ class AgentManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) + @raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE) async def detach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState: """ Detaches a source from an agent. @@ -1787,6 +1797,8 @@ class AgentManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) + @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK) async def attach_block_async(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState: """Attaches a block to an agent. For sleeptime agents, also attaches to paired agents in the same group.""" async with db_registry.async_session() as session: @@ -2373,8 +2385,9 @@ class AgentManager: # Tool Management # ====================================================================================================================== @enforce_types - @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) + @raise_on_invalid_id(param_name="tool_id", expected_prefix=PrimitiveType.TOOL) async def attach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> None: """ Attaches a tool to an agent. @@ -2443,6 +2456,7 @@ class AgentManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) async def bulk_attach_tools_async(self, agent_id: str, tool_ids: List[str], actor: PydanticUser) -> None: """ Efficiently attaches multiple tools to an agent in a single operation. @@ -2608,6 +2622,8 @@ class AgentManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) + @raise_on_invalid_id(param_name="tool_id", expected_prefix=PrimitiveType.TOOL) async def detach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> None: """ Detaches a tool from an agent. @@ -2637,6 +2653,7 @@ class AgentManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) async def bulk_detach_tools_async(self, agent_id: str, tool_ids: List[str], actor: PydanticUser) -> None: """ Efficiently detaches multiple tools from an agent in a single operation. @@ -2673,6 +2690,7 @@ class AgentManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) async def modify_approvals_async(self, agent_id: str, tool_name: str, requires_approval: bool, actor: PydanticUser) -> None: def is_target_rule(rule): return rule.tool_name == tool_name and rule.type == "requires_approval" @@ -3021,6 +3039,7 @@ class AgentManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) async def get_agent_files_config_async(self, agent_id: str, actor: PydanticUser) -> Tuple[int, int]: """Get per_file_view_window_char_limit and max_files_open for an agent. @@ -3077,6 +3096,7 @@ class AgentManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) async def get_agent_max_files_open_async(self, agent_id: str, actor: PydanticUser) -> int: """Get max_files_open for an agent. @@ -3105,6 +3125,7 @@ class AgentManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) async def get_agent_per_file_view_window_char_limit_async(self, agent_id: str, actor: PydanticUser) -> int: """Get per_file_view_window_char_limit for an agent. @@ -3131,7 +3152,9 @@ class AgentManager: return row + @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) async def get_context_window(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview: agent_state, system_message, num_messages, num_archival_memories = await self.rebuild_system_prompt_async( agent_id=agent_id, actor=actor, force=True, dry_run=True diff --git a/letta/services/archive_manager.py b/letta/services/archive_manager.py index 203a1ffb..7eeb51d1 100644 --- a/letta/services/archive_manager.py +++ b/letta/services/archive_manager.py @@ -7,11 +7,12 @@ from letta.log import get_logger from letta.orm import ArchivalPassage, Archive as ArchiveModel, ArchivesAgents from letta.otel.tracing import trace_method from letta.schemas.archive import Archive as PydanticArchive -from letta.schemas.enums import VectorDBProvider +from letta.schemas.enums import PrimitiveType, VectorDBProvider from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.settings import settings from letta.utils import enforce_types +from letta.validators import raise_on_invalid_id logger = get_logger(__name__) @@ -47,6 +48,7 @@ class ArchiveManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE) async def get_archive_by_id_async( self, archive_id: str, @@ -63,6 +65,7 @@ class ArchiveManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE) async def update_archive_async( self, archive_id: str, @@ -89,6 +92,7 @@ class ArchiveManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) async def list_archives_async( self, *, @@ -136,6 +140,8 @@ class ArchiveManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) + @raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE) async def attach_agent_to_archive_async( self, agent_id: str, @@ -172,6 +178,7 @@ class ArchiveManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) async def get_default_archive_for_agent_async( self, agent_id: str, @@ -204,6 +211,7 @@ class ArchiveManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE) async def delete_archive_async( self, archive_id: str, @@ -221,6 +229,7 @@ class ArchiveManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) async def get_or_create_default_archive_for_agent_async( self, agent_id: str, @@ -291,6 +300,7 @@ class ArchiveManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE) async def get_agents_for_archive_async( self, archive_id: str, @@ -333,6 +343,7 @@ class ArchiveManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE) async def get_or_set_vector_db_namespace_async( self, archive_id: str, diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index 8231f854..1fbdce1e 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -15,11 +15,12 @@ from letta.orm.errors import NoResultFound from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState as PydanticAgentState from letta.schemas.block import Block as PydanticBlock, BlockUpdate -from letta.schemas.enums import ActorType +from letta.schemas.enums import ActorType, PrimitiveType from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.settings import DatabaseChoice, settings from letta.utils import enforce_types +from letta.validators import raise_on_invalid_id logger = get_logger(__name__) @@ -134,10 +135,9 @@ class BlockManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK) async def update_block_async(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock: """Update a block by its ID with the given BlockUpdate object.""" - # Safety check for block - async with db_registry.async_session() as session: block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor) update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) @@ -155,6 +155,7 @@ class BlockManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK) async def delete_block_async(self, block_id: str, actor: PydanticUser) -> None: """Delete a block by its ID.""" async with db_registry.async_session() as session: @@ -353,6 +354,7 @@ class BlockManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK) async def get_block_by_id_async(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]: """Retrieve a block by its name.""" async with db_registry.async_session() as session: @@ -412,6 +414,7 @@ class BlockManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK) async def get_agents_for_block_async( self, block_id: str, @@ -595,6 +598,8 @@ class BlockManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK) + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) async def checkpoint_block_async( self, block_id: str, @@ -703,6 +708,7 @@ class BlockManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK) async def undo_checkpoint_block( self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None ) -> PydanticBlock: @@ -753,6 +759,7 @@ class BlockManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK) async def redo_checkpoint_block( self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None ) -> PydanticBlock: diff --git a/letta/services/file_manager.py b/letta/services/file_manager.py index 256b9e7b..596b6356 100644 --- a/letta/services/file_manager.py +++ b/letta/services/file_manager.py @@ -15,7 +15,7 @@ from letta.orm.errors import NoResultFound from letta.orm.file import FileContent as FileContentModel, FileMetadata as FileMetadataModel from letta.orm.sqlalchemy_base import AccessType from letta.otel.tracing import trace_method -from letta.schemas.enums import FileProcessingStatus +from letta.schemas.enums import FileProcessingStatus, PrimitiveType from letta.schemas.file import FileMetadata as PydanticFileMetadata from letta.schemas.source import Source as PydanticSource from letta.schemas.source_metadata import FileStats, OrganizationSourcesStats, SourceStats @@ -23,6 +23,7 @@ from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.settings import settings from letta.utils import enforce_types +from letta.validators import raise_on_invalid_id logger = get_logger(__name__) @@ -93,6 +94,7 @@ class FileManager: # TODO: We make actor optional for now, but should most likely be enforced due to security reasons @enforce_types @trace_method + @raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE) # @async_redis_cache( # key_func=lambda self, file_id, actor=None, include_content=False, strip_directory_prefix=False: f"{file_id}:{actor.organization_id if actor else 'none'}:{include_content}:{strip_directory_prefix}", # prefix="file_content", @@ -135,6 +137,7 @@ class FileManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE) async def update_file_status( self, *, @@ -171,7 +174,6 @@ class FileManager: * 1st round-trip → UPDATE with optional state validation * 2nd round-trip → SELECT fresh row (same as read_async) if update succeeded """ - if processing_status is None and error_message is None and total_chunks is None and chunks_embedded is None: raise ValueError("Nothing to update") @@ -353,6 +355,7 @@ class FileManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE) async def upsert_file_content( self, *, @@ -398,6 +401,7 @@ class FileManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE) async def list_files( self, source_id: str, @@ -455,6 +459,7 @@ class FileManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE) async def delete_file(self, file_id: str, actor: PydanticUser) -> PydanticFileMetadata: """Delete a file by its ID.""" async with db_registry.async_session() as session: @@ -509,6 +514,7 @@ class FileManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE) # @async_redis_cache( # key_func=lambda self, original_filename, source_id, actor: f"{original_filename}:{source_id}:{actor.organization_id}", # prefix="file_by_name", diff --git a/letta/services/group_manager.py b/letta/services/group_manager.py index 3f67e003..a16b942e 100644 --- a/letta/services/group_manager.py +++ b/letta/services/group_manager.py @@ -9,6 +9,7 @@ from letta.orm.errors import NoResultFound from letta.orm.group import Group as GroupModel from letta.orm.message import Message as MessageModel from letta.otel.tracing import trace_method +from letta.schemas.enums import PrimitiveType from letta.schemas.group import Group as PydanticGroup, GroupCreate, GroupUpdate, InternalTemplateGroupCreate, ManagerType from letta.schemas.letta_message import LettaMessage from letta.schemas.message import Message as PydanticMessage @@ -16,6 +17,7 @@ from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.settings import DatabaseChoice, settings from letta.utils import enforce_types +from letta.validators import raise_on_invalid_id class GroupManager: @@ -62,6 +64,7 @@ class GroupManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP) async def retrieve_group_async(self, group_id: str, actor: PydanticUser) -> PydanticGroup: async with db_registry.async_session() as session: group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor) @@ -119,6 +122,7 @@ class GroupManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP) async def modify_group_async(self, group_id: str, group_update: GroupUpdate, actor: PydanticUser) -> PydanticGroup: async with db_registry.async_session() as session: group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor) @@ -182,6 +186,7 @@ class GroupManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP) async def delete_group_async(self, group_id: str, actor: PydanticUser) -> None: async with db_registry.async_session() as session: group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor) @@ -189,6 +194,7 @@ class GroupManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP) async def list_group_messages_async( self, actor: PydanticUser, @@ -226,6 +232,7 @@ class GroupManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP) async def reset_messages_async(self, group_id: str, actor: PydanticUser) -> None: async with db_registry.async_session() as session: # Ensure group is loadable by user @@ -241,6 +248,7 @@ class GroupManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP) async def bump_turns_counter_async(self, group_id: str, actor: PydanticUser) -> int: async with db_registry.async_session() as session: # Ensure group is loadable by user @@ -253,6 +261,8 @@ class GroupManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP) + @raise_on_invalid_id(param_name="last_processed_message_id", expected_prefix=PrimitiveType.MESSAGE) async def get_last_processed_message_id_and_update_async( self, group_id: str, last_processed_message_id: str, actor: PydanticUser ) -> str: diff --git a/letta/services/identity_manager.py b/letta/services/identity_manager.py index 2a01263c..f84d5d73 100644 --- a/letta/services/identity_manager.py +++ b/letta/services/identity_manager.py @@ -12,6 +12,7 @@ from letta.orm.identity import Identity as IdentityModel from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState from letta.schemas.block import Block +from letta.schemas.enums import PrimitiveType from letta.schemas.identity import ( Identity as PydanticIdentity, IdentityCreate, @@ -24,6 +25,7 @@ from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.settings import DatabaseChoice, settings from letta.utils import enforce_types +from letta.validators import raise_on_invalid_id class IdentityManager: @@ -62,6 +64,7 @@ class IdentityManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY) async def get_identity_async(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity: async with db_registry.async_session() as session: identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor) @@ -143,6 +146,7 @@ class IdentityManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY) async def update_identity_async( self, identity_id: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False ) -> PydanticIdentity: @@ -206,6 +210,7 @@ class IdentityManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY) async def upsert_identity_properties_async( self, identity_id: str, properties: List[IdentityProperty], actor: PydanticUser ) -> PydanticIdentity: @@ -223,6 +228,7 @@ class IdentityManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY) async def delete_identity_async(self, identity_id: str, actor: PydanticUser) -> None: async with db_registry.async_session() as session: identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor) @@ -280,6 +286,7 @@ class IdentityManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY) async def list_agents_for_identity_async( self, identity_id: str, @@ -311,6 +318,7 @@ class IdentityManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY) async def list_blocks_for_identity_async( self, identity_id: str, diff --git a/letta/services/job_manager.py b/letta/services/job_manager.py index 615d0866..33a6e674 100644 --- a/letta/services/job_manager.py +++ b/letta/services/job_manager.py @@ -14,7 +14,7 @@ from letta.orm.message import Message as MessageModel from letta.orm.sqlalchemy_base import AccessType from letta.orm.step import Step, Step as StepModel from letta.otel.tracing import log_event, trace_method -from letta.schemas.enums import JobStatus, JobType, MessageRole +from letta.schemas.enums import JobStatus, JobType, MessageRole, PrimitiveType from letta.schemas.job import BatchJob as PydanticBatchJob, Job as PydanticJob, JobUpdate, LettaRequestConfig from letta.schemas.letta_message import LettaMessage from letta.schemas.letta_stop_reason import StopReasonType @@ -26,6 +26,7 @@ from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.services.helpers.agent_manager_helper import validate_agent_exists_async from letta.utils import enforce_types +from letta.validators import raise_on_invalid_id logger = get_logger(__name__) @@ -70,6 +71,7 @@ class JobManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB) async def update_job_by_id_async( self, job_id: str, job_update: JobUpdate, actor: PydanticUser, safe_update: bool = False ) -> PydanticJob: @@ -147,6 +149,7 @@ class JobManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB) async def safe_update_job_status_async( self, job_id: str, @@ -187,6 +190,7 @@ class JobManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB) async def get_job_by_id_async(self, job_id: str, actor: PydanticUser) -> PydanticJob: """Fetch a job by its ID asynchronously.""" async with db_registry.async_session() as session: @@ -301,6 +305,7 @@ class JobManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB) async def delete_job_by_id_async(self, job_id: str, actor: PydanticUser) -> PydanticJob: """Delete a job by its ID.""" async with db_registry.async_session() as session: @@ -310,6 +315,7 @@ class JobManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN) async def get_run_messages( self, run_id: str, @@ -367,6 +373,7 @@ class JobManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN) async def get_step_messages( self, run_id: str, @@ -447,6 +454,7 @@ class JobManager: return job @enforce_types + @raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB) async def record_ttft(self, job_id: str, ttft_ns: int, actor: PydanticUser) -> None: """Record time to first token for a run""" try: @@ -459,6 +467,7 @@ class JobManager: logger.warning(f"Failed to record TTFT for job {job_id}: {e}") @enforce_types + @raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB) async def record_response_duration(self, job_id: str, total_duration_ns: int, actor: PydanticUser) -> None: """Record total response duration for a run""" try: @@ -529,6 +538,7 @@ class JobManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB) async def get_job_steps( self, job_id: str, diff --git a/letta/services/mcp_manager.py b/letta/services/mcp_manager.py index 60095e04..77475455 100644 --- a/letta/services/mcp_manager.py +++ b/letta/services/mcp_manager.py @@ -25,6 +25,7 @@ from letta.orm.errors import NoResultFound from letta.orm.mcp_oauth import MCPOAuth, OAuthSessionStatus from letta.orm.mcp_server import MCPServer as MCPServerModel from letta.orm.tool import Tool as ToolModel +from letta.schemas.enums import PrimitiveType from letta.schemas.mcp import ( MCPOAuthSession, MCPOAuthSessionCreate, @@ -47,6 +48,7 @@ from letta.services.mcp.streamable_http_client import AsyncStreamableHTTPMCPClie from letta.services.tool_manager import ToolManager from letta.settings import settings, tool_settings from letta.utils import enforce_types, printd, safe_create_task_with_return +from letta.validators import raise_on_invalid_id logger = get_logger(__name__) @@ -60,6 +62,7 @@ class MCPManager: self.cached_mcp_servers = {} # maps id -> async connection @enforce_types + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) async def list_mcp_server_tools(self, mcp_server_name: str, actor: PydanticUser, agent_id: Optional[str] = None) -> List[MCPTool]: """Get a list of all tools for a specific MCP server.""" mcp_client = None diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index f457e75e..a66e059a 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -10,7 +10,7 @@ from letta.log import get_logger from letta.orm.errors import NoResultFound from letta.orm.message import Message as MessageModel from letta.otel.tracing import trace_method -from letta.schemas.enums import MessageRole +from letta.schemas.enums import MessageRole, PrimitiveType from letta.schemas.letta_message import LettaMessageUpdateUnion from letta.schemas.letta_message_content import ImageSourceType, LettaImage, MessageContentType, TextContent from letta.schemas.message import Message as PydanticMessage, MessageSearchResult, MessageUpdate @@ -20,6 +20,7 @@ from letta.services.file_manager import FileManager from letta.services.helpers.agent_manager_helper import validate_agent_exists_async from letta.settings import DatabaseChoice, settings from letta.utils import enforce_types, fire_and_forget +from letta.validators import raise_on_invalid_id logger = get_logger(__name__) @@ -308,6 +309,7 @@ class MessageManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="message_id", expected_prefix=PrimitiveType.MESSAGE) async def get_message_by_id_async(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]: """Fetch a message by ID.""" async with db_registry.async_session() as session: @@ -712,6 +714,7 @@ class MessageManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="message_id", expected_prefix=PrimitiveType.MESSAGE) async def delete_message_by_id_async(self, message_id: str, actor: PydanticUser, strict_mode: bool = False) -> bool: """Delete a message (async version with turbopuffer support).""" # capture agent_id before deletion diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index 108b9a19..907bdc49 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -2,12 +2,13 @@ from typing import List, Optional, Tuple, Union from letta.orm.provider import Provider as ProviderModel from letta.otel.tracing import trace_method -from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.enums import PrimitiveType, ProviderCategory, ProviderType from letta.schemas.providers import Provider as PydanticProvider, ProviderCheck, ProviderCreate, ProviderUpdate from letta.schemas.secret import Secret from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.utils import enforce_types +from letta.validators import raise_on_invalid_id class ProviderManager: @@ -40,6 +41,7 @@ class ProviderManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER) async def update_provider_async(self, provider_id: str, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider: """Update provider details.""" async with db_registry.async_session() as session: @@ -103,6 +105,7 @@ class ProviderManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER) async def delete_provider_by_id_async(self, provider_id: str, actor: PydanticUser): """Delete a provider.""" async with db_registry.async_session() as session: @@ -153,6 +156,7 @@ class ProviderManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER) async def get_provider_async(self, provider_id: str, actor: PydanticUser) -> PydanticProvider: async with db_registry.async_session() as session: provider_model = await ProviderModel.read_async(db_session=session, identifier=provider_id, actor=actor) diff --git a/letta/services/run_manager.py b/letta/services/run_manager.py index d3d70d9e..33677f4b 100644 --- a/letta/services/run_manager.py +++ b/letta/services/run_manager.py @@ -14,7 +14,7 @@ from letta.orm.run_metrics import RunMetrics as RunMetricsModel from letta.orm.sqlalchemy_base import AccessType from letta.orm.step import Step as StepModel from letta.otel.tracing import log_event, trace_method -from letta.schemas.enums import AgentType, ComparisonOperator, MessageRole, RunStatus +from letta.schemas.enums import AgentType, ComparisonOperator, MessageRole, RunStatus, PrimitiveType from letta.schemas.job import LettaRequestConfig from letta.schemas.letta_message import LettaMessage, LettaMessageUnion from letta.schemas.letta_response import LettaResponse @@ -31,6 +31,7 @@ from letta.services.helpers.agent_manager_helper import validate_agent_exists_as from letta.services.message_manager import MessageManager from letta.services.step_manager import StepManager from letta.utils import enforce_types +from letta.validators import raise_on_invalid_id logger = get_logger(__name__) @@ -85,6 +86,7 @@ class RunManager: return run.to_pydantic() @enforce_types + @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN) async def get_run_by_id(self, run_id: str, actor: PydanticUser) -> PydanticRun: """Get a run by its ID.""" async with db_registry.async_session() as session: @@ -176,6 +178,7 @@ class RunManager: return [run.to_pydantic() for run in runs] @enforce_types + @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN) async def delete_run(self, run_id: str, actor: PydanticUser) -> PydanticRun: """Delete a run by its ID.""" async with db_registry.async_session() as session: @@ -189,11 +192,11 @@ class RunManager: return pydantic_run @enforce_types + @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN) async def update_run_by_id_async( self, run_id: str, update: RunUpdate, actor: PydanticUser, refresh_result_messages: bool = True ) -> PydanticRun: """Update a run using a RunUpdate object.""" - async with db_registry.async_session() as session: run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor) @@ -327,6 +330,7 @@ class RunManager: return result @enforce_types + @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN) async def get_run_usage(self, run_id: str, actor: PydanticUser) -> LettaUsageStatistics: """Get usage statistics for a run.""" async with db_registry.async_session() as session: @@ -344,6 +348,7 @@ class RunManager: return total_usage @enforce_types + @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN) async def get_run_messages( self, run_id: str, @@ -378,6 +383,7 @@ class RunManager: return letta_messages @enforce_types + @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN) async def get_run_request_config(self, run_id: str, actor: PydanticUser) -> Optional[LettaRequestConfig]: """Get the letta request config from a run.""" async with db_registry.async_session() as session: @@ -388,6 +394,7 @@ class RunManager: return pydantic_run.request_config @enforce_types + @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN) async def get_run_metrics_async(self, run_id: str, actor: PydanticUser) -> PydanticRunMetrics: """Get metrics for a run.""" async with db_registry.async_session() as session: @@ -395,6 +402,7 @@ class RunManager: return metrics.to_pydantic() @enforce_types + @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN) async def get_run_steps( self, run_id: str, diff --git a/letta/services/sandbox_config_manager.py b/letta/services/sandbox_config_manager.py index ad4a7563..2f6191ac 100644 --- a/letta/services/sandbox_config_manager.py +++ b/letta/services/sandbox_config_manager.py @@ -5,7 +5,7 @@ from letta.log import get_logger from letta.orm.errors import NoResultFound from letta.orm.sandbox_config import SandboxConfig as SandboxConfigModel, SandboxEnvironmentVariable as SandboxEnvVarModel from letta.otel.tracing import trace_method -from letta.schemas.enums import SandboxType +from letta.schemas.enums import PrimitiveType, SandboxType from letta.schemas.environment_variables import ( SandboxEnvironmentVariable as PydanticEnvVar, SandboxEnvironmentVariableCreate, @@ -20,6 +20,7 @@ from letta.schemas.sandbox_config import ( from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.utils import enforce_types, printd +from letta.validators import raise_on_invalid_id logger = get_logger(__name__) @@ -101,6 +102,7 @@ class SandboxConfigManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG) async def update_sandbox_config_async( self, sandbox_config_id: str, sandbox_update: SandboxConfigUpdate, actor: PydanticUser ) -> PydanticSandboxConfig: @@ -129,6 +131,7 @@ class SandboxConfigManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG) async def delete_sandbox_config_async(self, sandbox_config_id: str, actor: PydanticUser) -> PydanticSandboxConfig: """Delete a sandbox configuration by its ID.""" async with db_registry.async_session() as session: @@ -176,6 +179,7 @@ class SandboxConfigManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG) async def create_sandbox_env_var_async( self, env_var_create: SandboxEnvironmentVariableCreate, sandbox_config_id: str, actor: PydanticUser ) -> PydanticEnvVar: @@ -266,6 +270,7 @@ class SandboxConfigManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG) async def list_sandbox_env_vars_async( self, sandbox_config_id: str, @@ -302,6 +307,7 @@ class SandboxConfigManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG) def get_sandbox_env_vars_as_dict( self, sandbox_config_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50 ) -> Dict[str, str]: @@ -315,6 +321,7 @@ class SandboxConfigManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG) async def get_sandbox_env_vars_as_dict_async( self, sandbox_config_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50 ) -> Dict[str, str]: @@ -324,6 +331,7 @@ class SandboxConfigManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG) async def get_sandbox_env_var_by_key_and_sandbox_config_id_async( self, key: str, sandbox_config_id: str, actor: Optional[PydanticUser] = None ) -> Optional[PydanticEnvVar]: diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index d738e4f1..9ef2cb6f 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -11,11 +11,12 @@ from letta.orm.source import Source as SourceModel from letta.orm.sources_agents import SourcesAgents from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState as PydanticAgentState -from letta.schemas.enums import VectorDBProvider +from letta.schemas.enums import PrimitiveType, VectorDBProvider from letta.schemas.source import Source as PydanticSource, SourceUpdate from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.utils import enforce_types, printd +from letta.validators import raise_on_invalid_id class SourceManager: @@ -201,6 +202,7 @@ class SourceManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE) async def update_source(self, source_id: str, source_update: SourceUpdate, actor: PydanticUser) -> PydanticSource: """Update a source by its ID with the given SourceUpdate object.""" async with db_registry.async_session() as session: @@ -224,6 +226,7 @@ class SourceManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE) async def delete_source(self, source_id: str, actor: PydanticUser) -> PydanticSource: """Delete a source by its ID.""" async with db_registry.async_session() as session: @@ -268,6 +271,7 @@ class SourceManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE) async def list_attached_agents( self, source_id: str, actor: PydanticUser, ids_only: bool = False ) -> Union[List[PydanticAgentState], List[str]]: @@ -321,6 +325,7 @@ class SourceManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE) async def get_agents_for_source_id(self, source_id: str, actor: PydanticUser) -> List[str]: """ Get all agent IDs associated with a given source ID. @@ -347,6 +352,7 @@ class SourceManager: # TODO: We make actor optional for now, but should most likely be enforced due to security reasons @enforce_types @trace_method + @raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE) async def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]: """Retrieve a source by its ID.""" async with db_registry.async_session() as session: diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py index 3609e11f..a74d6614 100644 --- a/letta/services/step_manager.py +++ b/letta/services/step_manager.py @@ -13,7 +13,7 @@ from letta.orm.sqlalchemy_base import AccessType from letta.orm.step import Step as StepModel from letta.orm.step_metrics import StepMetrics as StepMetricsModel from letta.otel.tracing import get_trace_id, trace_method -from letta.schemas.enums import StepStatus +from letta.schemas.enums import PrimitiveType, StepStatus from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import Message as PydanticMessage from letta.schemas.openai.chat_completion_response import UsageStatistics @@ -22,6 +22,7 @@ from letta.schemas.step_metrics import StepMetrics as PydanticStepMetrics from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.utils import enforce_types +from letta.validators import raise_on_invalid_id class FeedbackType(str, Enum): @@ -32,6 +33,8 @@ class FeedbackType(str, Enum): class StepManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) + @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN) async def list_steps_async( self, actor: PydanticUser, @@ -79,6 +82,10 @@ class StepManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) + @raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER) + @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN) + @raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP) def log_step( self, actor: PydanticUser, @@ -133,6 +140,10 @@ class StepManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) + @raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER) + @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN) + @raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP) async def log_step_async( self, actor: PydanticUser, @@ -196,6 +207,7 @@ class StepManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP) async def get_step_async(self, step_id: str, actor: PydanticUser) -> PydanticStep: async with db_registry.async_session() as session: step = await StepModel.read_async(db_session=session, identifier=step_id, actor=actor) @@ -203,6 +215,7 @@ class StepManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP) async def get_step_metrics_async(self, step_id: str, actor: PydanticUser) -> PydanticStepMetrics: async with db_registry.async_session() as session: metrics = await StepMetricsModel.read_async(db_session=session, identifier=step_id, actor=actor) @@ -210,6 +223,7 @@ class StepManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP) async def add_feedback_async( self, step_id: str, feedback: FeedbackType | None, actor: PydanticUser, tags: list[str] | None = None ) -> PydanticStep: @@ -225,6 +239,7 @@ class StepManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP) async def update_step_transaction_id(self, actor: PydanticUser, step_id: str, transaction_id: str) -> PydanticStep: """Update the transaction ID for a step. @@ -252,6 +267,7 @@ class StepManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP) async def list_step_messages_async( self, step_id: str, @@ -276,6 +292,7 @@ class StepManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP) async def update_step_stop_reason(self, actor: PydanticUser, step_id: str, stop_reason: StopReasonType) -> PydanticStep: """Update the stop reason for a step. @@ -303,6 +320,7 @@ class StepManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP) async def update_step_error_async( self, actor: PydanticUser, @@ -348,6 +366,7 @@ class StepManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP) async def update_step_success_async( self, actor: PydanticUser, @@ -388,6 +407,7 @@ class StepManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP) async def update_step_cancelled_async( self, actor: PydanticUser, @@ -423,6 +443,9 @@ class StepManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP) + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) + @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN) async def record_step_metrics_async( self, actor: PydanticUser, diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 7bf17514..7df4c9c0 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -27,7 +27,7 @@ from letta.log import get_logger from letta.orm.errors import NoResultFound from letta.orm.tool import Tool as ToolModel from letta.otel.tracing import trace_method -from letta.schemas.enums import ToolType +from letta.schemas.enums import PrimitiveType, ToolType from letta.schemas.tool import Tool as PydanticTool, ToolCreate, ToolUpdate from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry @@ -36,6 +36,7 @@ from letta.services.mcp.types import SSEServerConfig, StdioServerConfig from letta.services.tool_schema_generator import generate_schema_for_tool_creation, generate_schema_for_tool_update from letta.settings import settings from letta.utils import enforce_types, printd +from letta.validators import raise_on_invalid_id logger = get_logger(__name__) @@ -203,6 +204,7 @@ class ToolManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="tool_id", expected_prefix=PrimitiveType.TOOL) async def get_tool_by_id_async(self, tool_id: str, actor: PydanticUser) -> PydanticTool: """Fetch a tool by its ID.""" async with db_registry.async_session() as session: @@ -235,6 +237,7 @@ class ToolManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="tool_id", expected_prefix=PrimitiveType.TOOL) async def tool_exists_async(self, tool_id: str, actor: PydanticUser) -> bool: """Check if a tool exists and belongs to the user's organization (lightweight check).""" async with db_registry.async_session() as session: @@ -507,6 +510,7 @@ class ToolManager: @enforce_types @trace_method + @raise_on_invalid_id(param_name="tool_id", expected_prefix=PrimitiveType.TOOL) async def update_tool_by_id_async( self, tool_id: str, @@ -604,6 +608,7 @@ class ToolManager: @enforce_types @trace_method + # @raise_on_invalid_id This is commented out bc it's called by _list_tools_async, when it encounters malformed tools (i.e. if id is invalid will fail validation on deletion) async def delete_tool_by_id_async(self, tool_id: str, actor: PydanticUser) -> None: """Delete a tool by its ID.""" async with db_registry.async_session() as session: diff --git a/letta/validators.py b/letta/validators.py index 5f297946..8eb75a2e 100644 --- a/letta/validators.py +++ b/letta/validators.py @@ -1,48 +1,21 @@ +import inspect import re +from functools import wraps from typing import Annotated from fastapi import Path -from letta.schemas.agent import AgentState -from letta.schemas.archive import ArchiveBase -from letta.schemas.block import BaseBlock -from letta.schemas.file import FileMetadataBase -from letta.schemas.folder import BaseFolder -from letta.schemas.group import GroupBase -from letta.schemas.identity import IdentityBase -from letta.schemas.job import JobBase -from letta.schemas.message import BaseMessage -from letta.schemas.providers import ProviderBase -from letta.schemas.run import RunBase -from letta.schemas.sandbox_config import SandboxConfigBase -from letta.schemas.source import BaseSource -from letta.schemas.step import StepBase -from letta.schemas.tool import BaseTool +from letta.errors import LettaInvalidArgumentError +from letta.schemas.enums import PrimitiveType # PrimitiveType is now in schemas.enums -# TODO: extract this list from routers/v1/__init__.py and ROUTERS -primitives = [ - AgentState.__id_prefix__, - BaseMessage.__id_prefix__, - RunBase.__id_prefix__, - JobBase.__id_prefix__, - GroupBase.__id_prefix__, - BaseBlock.__id_prefix__, - FileMetadataBase.__id_prefix__, - BaseFolder.__id_prefix__, - BaseSource.__id_prefix__, - BaseTool.__id_prefix__, - ArchiveBase.__id_prefix__, - ProviderBase.__id_prefix__, - SandboxConfigBase.__id_prefix__, - StepBase.__id_prefix__, - IdentityBase.__id_prefix__, -] +# Map from PrimitiveType to the actual prefix string (which is just the enum value) +PRIMITIVE_ID_PREFIXES = {primitive_type: primitive_type.value for primitive_type in PrimitiveType} PRIMITIVE_ID_PATTERNS = { # f-string interpolation gets confused because of the regex's required curly braces {} - primitive: re.compile("^" + primitive + "-[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$") - for primitive in primitives + prefix: re.compile("^" + prefix + "-[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$") + for prefix in PRIMITIVE_ID_PREFIXES.values() } @@ -67,28 +40,65 @@ def _create_path_validator_factory(primitive: str): # PATH_VALIDATORS now contains factory functions, not Path objects -# Usage: folder_id: str = PATH_VALIDATORS[BaseFolder.__id_prefix__]() -PATH_VALIDATORS = {primitive: _create_path_validator_factory(primitive) for primitive in primitives} - - -def is_valid_id(primitive: str, id: str) -> bool: - return PRIMITIVE_ID_PATTERNS[primitive].match(id) is not None +# Usage: folder_id: str = PATH_VALIDATORS[PrimitiveType.FOLDER.value]() +PATH_VALIDATORS = {primitive_type.value: _create_path_validator_factory(primitive_type.value) for primitive_type in PrimitiveType} # Type aliases for common ID types # These can be used directly in route handler signatures for cleaner code -AgentId = Annotated[str, PATH_VALIDATORS[AgentState.__id_prefix__]()] -ToolId = Annotated[str, PATH_VALIDATORS[BaseTool.__id_prefix__]()] -SourceId = Annotated[str, PATH_VALIDATORS[BaseSource.__id_prefix__]()] -BlockId = Annotated[str, PATH_VALIDATORS[BaseBlock.__id_prefix__]()] -MessageId = Annotated[str, PATH_VALIDATORS[BaseMessage.__id_prefix__]()] -RunId = Annotated[str, PATH_VALIDATORS[RunBase.__id_prefix__]()] -JobId = Annotated[str, PATH_VALIDATORS[JobBase.__id_prefix__]()] -GroupId = Annotated[str, PATH_VALIDATORS[GroupBase.__id_prefix__]()] -FileId = Annotated[str, PATH_VALIDATORS[FileMetadataBase.__id_prefix__]()] -FolderId = Annotated[str, PATH_VALIDATORS[BaseFolder.__id_prefix__]()] -ArchiveId = Annotated[str, PATH_VALIDATORS[ArchiveBase.__id_prefix__]()] -ProviderId = Annotated[str, PATH_VALIDATORS[ProviderBase.__id_prefix__]()] -SandboxConfigId = Annotated[str, PATH_VALIDATORS[SandboxConfigBase.__id_prefix__]()] -StepId = Annotated[str, PATH_VALIDATORS[StepBase.__id_prefix__]()] -IdentityId = Annotated[str, PATH_VALIDATORS[IdentityBase.__id_prefix__]()] +AgentId = Annotated[str, PATH_VALIDATORS[PrimitiveType.AGENT.value]()] +ToolId = Annotated[str, PATH_VALIDATORS[PrimitiveType.TOOL.value]()] +SourceId = Annotated[str, PATH_VALIDATORS[PrimitiveType.SOURCE.value]()] +BlockId = Annotated[str, PATH_VALIDATORS[PrimitiveType.BLOCK.value]()] +MessageId = Annotated[str, PATH_VALIDATORS[PrimitiveType.MESSAGE.value]()] +RunId = Annotated[str, PATH_VALIDATORS[PrimitiveType.RUN.value]()] +JobId = Annotated[str, PATH_VALIDATORS[PrimitiveType.JOB.value]()] +GroupId = Annotated[str, PATH_VALIDATORS[PrimitiveType.GROUP.value]()] +FileId = Annotated[str, PATH_VALIDATORS[PrimitiveType.FILE.value]()] +FolderId = Annotated[str, PATH_VALIDATORS[PrimitiveType.FOLDER.value]()] +ArchiveId = Annotated[str, PATH_VALIDATORS[PrimitiveType.ARCHIVE.value]()] +ProviderId = Annotated[str, PATH_VALIDATORS[PrimitiveType.PROVIDER.value]()] +SandboxConfigId = Annotated[str, PATH_VALIDATORS[PrimitiveType.SANDBOX_CONFIG.value]()] +StepId = Annotated[str, PATH_VALIDATORS[PrimitiveType.STEP.value]()] +IdentityId = Annotated[str, PATH_VALIDATORS[PrimitiveType.IDENTITY.value]()] + + +def raise_on_invalid_id(param_name: str, expected_prefix: PrimitiveType): + """ + Decorator that validates an ID parameter has the expected prefix format. + Can be stacked multiple times on the same function to validate different IDs. + + Args: + param_name: The name of the function parameter to validate (e.g., "agent_id") + expected_prefix: The expected primitive type (e.g., PrimitiveType.AGENT) + + Example: + @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT) + @raise_on_invalid_id(param_name="folder_id", expected_prefix=PrimitiveType.FOLDER) + def my_function(agent_id: str, folder_id: str): + pass + """ + + def decorator(function): + @wraps(function) + def wrapper(*args, **kwargs): + sig = inspect.signature(function) + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + + if param_name in bound_args.arguments: + arg_value = bound_args.arguments[param_name] + + if arg_value is not None: + prefix = PRIMITIVE_ID_PREFIXES[expected_prefix] + if PRIMITIVE_ID_PATTERNS[prefix].match(arg_value) is None: + raise LettaInvalidArgumentError( + message=f"Invalid {expected_prefix.value} ID format: {arg_value}. Expected format: '{prefix}-'", + argument_name=param_name, + ) + + return function(*args, **kwargs) + + return wrapper + + return decorator diff --git a/tests/integration_test_turbopuffer.py b/tests/integration_test_turbopuffer.py index 0ac684c1..f642e75e 100644 --- a/tests/integration_test_turbopuffer.py +++ b/tests/integration_test_turbopuffer.py @@ -286,7 +286,7 @@ async def test_turbopuffer_metadata_attributes(default_user, enable_turbopuffer) pytest.skip("No Turbopuffer API key available") client = TurbopufferClient() - archive_id = f"test-archive-{datetime.now().timestamp()}" + archive_id = f"archive-{uuid.uuid4()}" try: # Insert passages with various metadata @@ -391,7 +391,7 @@ async def test_hybrid_search_with_real_tpuf(default_user, enable_turbopuffer): from letta.helpers.tpuf_client import TurbopufferClient client = TurbopufferClient() - archive_id = f"test-hybrid-{datetime.now().timestamp()}" + archive_id = f"archive-{uuid.uuid4()}" org_id = str(uuid.uuid4()) try: @@ -497,7 +497,7 @@ async def test_tag_filtering_with_real_tpuf(default_user, enable_turbopuffer): from letta.helpers.tpuf_client import TurbopufferClient client = TurbopufferClient() - archive_id = f"test-tags-{datetime.now().timestamp()}" + archive_id = f"archive-{uuid.uuid4()}" org_id = str(uuid.uuid4()) try: @@ -628,7 +628,7 @@ async def test_temporal_filtering_with_real_tpuf(default_user, enable_turbopuffe client = TurbopufferClient() # Create a unique archive ID for this test - archive_id = f"test-temporal-{uuid.uuid4()}" + archive_id = f"archive-{uuid.uuid4()}" try: # Create passages with different timestamps diff --git a/tests/managers/test_job_manager.py b/tests/managers/test_job_manager.py index 58c48df0..9c0d8d59 100644 --- a/tests/managers/test_job_manager.py +++ b/tests/managers/test_job_manager.py @@ -44,7 +44,7 @@ from letta.constants import ( MULTI_AGENT_TOOLS, ) from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client -from letta.errors import LettaAgentNotFoundError +from letta.errors import LettaAgentNotFoundError, LettaInvalidArgumentError from letta.functions.functions import derive_openai_json_schema, parse_source_code from letta.functions.mcp_client.types import MCPTool from letta.helpers import ToolRulesSolver @@ -233,7 +233,7 @@ async def test_update_job_auto_complete(server: SyncServer, default_user): async def test_get_job_not_found(server: SyncServer, default_user): """Test fetching a non-existent job.""" non_existent_job_id = "nonexistent-id" - with pytest.raises(NoResultFound): + with pytest.raises(LettaInvalidArgumentError): await server.job_manager.get_job_by_id_async(non_existent_job_id, actor=default_user) @@ -241,7 +241,7 @@ async def test_get_job_not_found(server: SyncServer, default_user): async def test_delete_job_not_found(server: SyncServer, default_user): """Test deleting a non-existent job.""" non_existent_job_id = "nonexistent-id" - with pytest.raises(NoResultFound): + with pytest.raises(LettaInvalidArgumentError): await server.job_manager.delete_job_by_id_async(non_existent_job_id, actor=default_user) @@ -412,9 +412,6 @@ async def test_e2e_job_callback(monkeypatch, server: SyncServer, default_user): # ====================================================================================================================== - - - @pytest.mark.asyncio async def test_record_ttft(server: SyncServer, default_user): """Test recording time to first token for a job.""" @@ -478,9 +475,11 @@ async def test_record_timing_metrics_together(server: SyncServer, default_user): @pytest.mark.asyncio async def test_record_timing_invalid_job(server: SyncServer, default_user): - """Test recording timing metrics for non-existent job fails gracefully.""" - # Try to record TTFT for non-existent job - should not raise exception but log warning - await server.job_manager.record_ttft("nonexistent_job_id", 1_000_000_000, default_user) + """Test recording timing metrics for non-existent job raises LettaInvalidArgumentError.""" + # Try to record TTFT for non-existent job - should raise LettaInvalidArgumentError + with pytest.raises(LettaInvalidArgumentError): + await server.job_manager.record_ttft("nonexistent_job_id", 1_000_000_000, default_user) - # Try to record response duration for non-existent job - should not raise exception but log warning - await server.job_manager.record_response_duration("nonexistent_job_id", 2_000_000_000, default_user) + # Try to record response duration for non-existent job - should raise LettaInvalidArgumentError + with pytest.raises(LettaInvalidArgumentError): + await server.job_manager.record_response_duration("nonexistent_job_id", 2_000_000_000, default_user) diff --git a/tests/managers/test_run_manager.py b/tests/managers/test_run_manager.py index e61de6cc..77cc574c 100644 --- a/tests/managers/test_run_manager.py +++ b/tests/managers/test_run_manager.py @@ -44,7 +44,7 @@ from letta.constants import ( MULTI_AGENT_TOOLS, ) from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client -from letta.errors import LettaAgentNotFoundError +from letta.errors import LettaAgentNotFoundError, LettaInvalidArgumentError from letta.functions.functions import derive_openai_json_schema, parse_source_code from letta.functions.mcp_client.types import MCPTool from letta.helpers import ToolRulesSolver @@ -241,7 +241,7 @@ async def test_update_run_auto_complete(server: SyncServer, default_user, sarah_ async def test_get_run_not_found(server: SyncServer, default_user): """Test fetching a non-existent run.""" non_existent_run_id = "nonexistent-id" - with pytest.raises(NoResultFound): + with pytest.raises(LettaInvalidArgumentError): await server.run_manager.get_run_by_id(non_existent_run_id, actor=default_user) @@ -249,7 +249,7 @@ async def test_get_run_not_found(server: SyncServer, default_user): async def test_delete_run_not_found(server: SyncServer, default_user): """Test deleting a non-existent run.""" non_existent_run_id = "nonexistent-id" - with pytest.raises(NoResultFound): + with pytest.raises(LettaInvalidArgumentError): await server.run_manager.delete_run(non_existent_run_id, actor=default_user) @@ -1268,7 +1268,7 @@ async def test_run_usage_stats_get_nonexistent_run(server: SyncServer, default_u """Test getting usage statistics for a nonexistent run.""" run_manager = server.run_manager - with pytest.raises(NoResultFound): + with pytest.raises(LettaInvalidArgumentError): await run_manager.get_run_usage(run_id="nonexistent_run", actor=default_user) @@ -1307,7 +1307,7 @@ async def test_get_run_request_config_none(server: SyncServer, sarah_agent, defa @pytest.mark.asyncio async def test_get_run_request_config_nonexistent_run(server: SyncServer, default_user): """Test getting request config for a nonexistent run.""" - with pytest.raises(NoResultFound): + with pytest.raises(LettaInvalidArgumentError): await server.run_manager.get_run_request_config("nonexistent_run", actor=default_user) @@ -1453,7 +1453,7 @@ async def test_run_metrics_num_steps_tracking(server: SyncServer, sarah_agent, d @pytest.mark.asyncio async def test_run_metrics_not_found(server: SyncServer, default_user): """Test getting metrics for non-existent run.""" - with pytest.raises(NoResultFound): + with pytest.raises(LettaInvalidArgumentError): await server.run_manager.get_run_metrics_async(run_id="nonexistent_run", actor=default_user) diff --git a/tests/managers/test_source_manager.py b/tests/managers/test_source_manager.py index dc45d7c5..509567e5 100644 --- a/tests/managers/test_source_manager.py +++ b/tests/managers/test_source_manager.py @@ -44,7 +44,7 @@ from letta.constants import ( MULTI_AGENT_TOOLS, ) from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client -from letta.errors import LettaAgentNotFoundError +from letta.errors import LettaAgentNotFoundError, LettaInvalidArgumentError from letta.functions.functions import derive_openai_json_schema, parse_source_code from letta.functions.mcp_client.types import MCPTool from letta.helpers import ToolRulesSolver @@ -174,21 +174,21 @@ async def test_detach_source(server: SyncServer, sarah_agent, default_source, de async def test_attach_source_nonexistent_agent(server: SyncServer, default_source, default_user): """Test attaching a source to a nonexistent agent.""" with pytest.raises(NoResultFound): - await server.agent_manager.attach_source_async(agent_id="nonexistent-agent-id", source_id=default_source.id, actor=default_user) + await server.agent_manager.attach_source_async(agent_id=f"agent-{uuid.uuid4()}", source_id=default_source.id, actor=default_user) @pytest.mark.asyncio async def test_attach_source_nonexistent_source(server: SyncServer, sarah_agent, default_user): """Test attaching a nonexistent source to an agent.""" with pytest.raises(NoResultFound): - await server.agent_manager.attach_source_async(agent_id=sarah_agent.id, source_id="nonexistent-source-id", actor=default_user) + await server.agent_manager.attach_source_async(agent_id=sarah_agent.id, source_id=f"source-{uuid.uuid4()}", actor=default_user) @pytest.mark.asyncio async def test_detach_source_nonexistent_agent(server: SyncServer, default_source, default_user): """Test detaching a source from a nonexistent agent.""" with pytest.raises(LettaAgentNotFoundError): - await server.agent_manager.detach_source_async(agent_id="nonexistent-agent-id", source_id=default_source.id, actor=default_user) + await server.agent_manager.detach_source_async(agent_id=f"agent-{uuid.uuid4()}", source_id=default_source.id, actor=default_user) @pytest.mark.asyncio @@ -234,7 +234,7 @@ async def test_list_attached_agents(server: SyncServer, sarah_agent, charles_age async def test_list_attached_agents_nonexistent_source(server: SyncServer, default_user): """Test listing agents for a nonexistent source.""" - with pytest.raises(NoResultFound): + with pytest.raises(LettaInvalidArgumentError): await server.source_manager.list_attached_agents(source_id="nonexistent-source-id", actor=default_user) diff --git a/tests/managers/test_tool_manager.py b/tests/managers/test_tool_manager.py index 6f041f50..be87e01b 100644 --- a/tests/managers/test_tool_manager.py +++ b/tests/managers/test_tool_manager.py @@ -220,7 +220,7 @@ async def test_bulk_detach_tools_idempotent(server: SyncServer, sarah_agent, pri @pytest.mark.asyncio async def test_bulk_detach_tools_nonexistent_agent(server: SyncServer, print_tool, other_tool, default_user): """Test bulk detaching tools from a nonexistent agent.""" - nonexistent_agent_id = "nonexistent-agent-id" + nonexistent_agent_id = f"agent-{uuid.uuid4()}" tool_ids = [print_tool.id, other_tool.id] with pytest.raises(LettaAgentNotFoundError): @@ -230,19 +230,19 @@ async def test_bulk_detach_tools_nonexistent_agent(server: SyncServer, print_too async def test_attach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user): """Test attaching a tool to a nonexistent agent.""" with pytest.raises(LettaAgentNotFoundError): - await server.agent_manager.attach_tool_async(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user) + await server.agent_manager.attach_tool_async(agent_id=f"agent-{uuid.uuid4()}", tool_id=print_tool.id, actor=default_user) async def test_attach_tool_nonexistent_tool(server: SyncServer, sarah_agent, default_user): """Test attaching a nonexistent tool to an agent.""" with pytest.raises(NoResultFound): - await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id="nonexistent-tool-id", actor=default_user) + await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=f"tool-{uuid.uuid4()}", actor=default_user) async def test_detach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user): """Test detaching a tool from a nonexistent agent.""" with pytest.raises(LettaAgentNotFoundError): - await server.agent_manager.detach_tool_async(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user) + await server.agent_manager.detach_tool_async(agent_id=f"agent-{uuid.uuid4()}", tool_id=print_tool.id, actor=default_user) @pytest.mark.asyncio @@ -330,7 +330,7 @@ async def test_bulk_attach_tools_nonexistent_tool(server: SyncServer, sarah_agen @pytest.mark.asyncio async def test_bulk_attach_tools_nonexistent_agent(server: SyncServer, print_tool, other_tool, default_user): """Test bulk attaching tools to a nonexistent agent.""" - nonexistent_agent_id = "nonexistent-agent-id" + nonexistent_agent_id = f"agent-{uuid.uuid4()}" tool_ids = [print_tool.id, other_tool.id] with pytest.raises(LettaAgentNotFoundError): @@ -2187,7 +2187,8 @@ async def test_list_tools_with_corrupted_tool(server: SyncServer, default_user, from letta.orm.tool import Tool as ToolModel async with db_registry.async_session() as session: - # Create a tool with no json_schema (corrupted state) + # Create a tool with corrupted ID format (bypassing validation) + # This simulates a tool that somehow got corrupted in the database corrupted_tool = ToolModel( id=f"tool-corrupted-{uuid.uuid4()}", name="corrupted_tool",