feat: runtime validation for ids for internal managers calls (#5544)

* claude coded first pass

* fix test cases to expect errors instead

* fix this

* let's see how letta-code did

* claude

* fix tests, remove dangling comments, retrofit all managers functions with decorator

* revert to main for these since we are not erroring on invalid tool and block ids

* reorder decorators

* finish refactoring test cases

* reorder agent_manager decorators and fix test tool manager

* add decorator on missing managers

* fix id sources

* remove redundant check

* uses enum now

* move to enum
This commit is contained in:
Kian Jones
2025-10-22 15:00:41 -07:00
committed by Caren Thomas
parent 0a083459c6
commit 45065297a0
37 changed files with 312 additions and 134 deletions

View File

@@ -8,6 +8,7 @@ from letta.constants import CORE_MEMORY_LINE_NUMBER_WARNING, DEFAULT_EMBEDDING_C
from letta.errors import AgentExportProcessingError
from letta.schemas.block import Block, CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import PrimitiveType
from letta.schemas.environment_variables import AgentEnvironmentVariable
from letta.schemas.file import FileStatus
from letta.schemas.group import Group
@@ -57,7 +58,7 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
embedding_config (EmbeddingConfig): The embedding configuration used by the agent.
"""
__id_prefix__ = "agent"
__id_prefix__ = PrimitiveType.AGENT.value
# NOTE: this is what is returned to the client and also what is used to initialize `Agent`
id: str = Field(..., description="The id of the agent. Assigned by the database.")

View File

@@ -3,12 +3,12 @@ from typing import Dict, Optional
from pydantic import Field
from letta.schemas.enums import VectorDBProvider
from letta.schemas.enums import PrimitiveType, VectorDBProvider
from letta.schemas.letta_base import OrmMetadataBase
class ArchiveBase(OrmMetadataBase):
__id_prefix__ = "archive"
__id_prefix__ = PrimitiveType.ARCHIVE.value
name: str = Field(..., description="The name of the archive")
description: Optional[str] = Field(None, description="A description of the archive")

View File

@@ -4,6 +4,7 @@ from typing import Any, Optional
from pydantic import ConfigDict, Field, model_validator
from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT, DEFAULT_HUMAN_BLOCK_DESCRIPTION, DEFAULT_PERSONA_BLOCK_DESCRIPTION
from letta.schemas.enums import PrimitiveType
from letta.schemas.letta_base import LettaBase
# block of the LLM context
@@ -12,7 +13,7 @@ from letta.schemas.letta_base import LettaBase
class BaseBlock(LettaBase, validate_assignment=True):
"""Base block of the LLM context"""
__id_prefix__ = "block"
__id_prefix__ = PrimitiveType.BLOCK.value
# data value
value: str = Field(..., description="Value of the block.")

View File

@@ -1,6 +1,31 @@
from enum import Enum, StrEnum
class PrimitiveType(str, Enum):
"""
Enum for all primitive resource types in Letta.
The enum values ARE the actual ID prefixes used in the system.
This serves as the single source of truth for all ID prefixes.
"""
AGENT = "agent"
MESSAGE = "message"
RUN = "run"
JOB = "job"
GROUP = "group"
BLOCK = "block"
FILE = "file"
FOLDER = "source" # Note: folder IDs use "source" prefix for historical reasons
SOURCE = "source"
TOOL = "tool"
ARCHIVE = "archive"
PROVIDER = "provider"
SANDBOX_CONFIG = "sandbox" # Note: sandbox_config IDs use "sandbox" prefix
STEP = "step"
IDENTITY = "identity"
class ProviderType(str, Enum):
anthropic = "anthropic"
azure = "azure"

View File

@@ -4,7 +4,7 @@ from typing import List, Optional
from pydantic import Field
from letta.schemas.enums import FileProcessingStatus
from letta.schemas.enums import FileProcessingStatus, PrimitiveType
from letta.schemas.letta_base import LettaBase
@@ -20,7 +20,7 @@ class FileStatus(str, Enum):
class FileMetadataBase(LettaBase):
"""Base class for FileMetadata schemas"""
__id_prefix__ = "file"
__id_prefix__ = PrimitiveType.FILE.value
# Core file metadata fields
source_id: str = Field(..., description="The unique identifier of the source associated with the document.")
@@ -61,7 +61,7 @@ class FileMetadata(FileMetadataBase):
class FileAgentBase(LettaBase):
"""Base class for the FileMetadata-⇄-Agent association schemas"""
__id_prefix__ = "file_agent"
__id_prefix__ = PrimitiveType.FILE.value
# Core file-agent association fields
agent_id: str = Field(..., description="Unique identifier of the agent.")

View File

@@ -4,6 +4,7 @@ from typing import Optional
from pydantic import Field
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import PrimitiveType
from letta.schemas.letta_base import LettaBase
@@ -12,7 +13,7 @@ class BaseFolder(LettaBase):
Shared attributes across all folder schemas.
"""
__id_prefix__ = "source" # TODO: change to "folder"
__id_prefix__ = PrimitiveType.FOLDER.value # TODO: change to "folder"
# Core folder fields
name: str = Field(..., description="The name of the folder.")

View File

@@ -3,6 +3,7 @@ from typing import Annotated, List, Literal, Optional, Union
from pydantic import BaseModel, Field
from letta.schemas.enums import PrimitiveType
from letta.schemas.letta_base import LettaBase
@@ -20,7 +21,7 @@ class ManagerConfig(BaseModel):
class GroupBase(LettaBase):
__id_prefix__ = "group"
__id_prefix__ = PrimitiveType.GROUP.value
class Group(GroupBase):

View File

@@ -3,6 +3,7 @@ from typing import List, Optional, Union
from pydantic import Field
from letta.schemas.enums import PrimitiveType
from letta.schemas.letta_base import LettaBase
@@ -28,7 +29,7 @@ class IdentityPropertyType(str, Enum):
class IdentityBase(LettaBase):
__id_prefix__ = "identity"
__id_prefix__ = PrimitiveType.IDENTITY.value
class IdentityProperty(LettaBase):

View File

@@ -3,6 +3,8 @@ from typing import TYPE_CHECKING, List, Optional
from pydantic import BaseModel, ConfigDict, Field
from letta.schemas.enums import PrimitiveType
if TYPE_CHECKING:
from letta.schemas.letta_request import LettaRequest
@@ -15,7 +17,7 @@ from letta.schemas.letta_stop_reason import StopReasonType
class JobBase(OrmMetadataBase):
__id_prefix__ = "job"
__id_prefix__ = PrimitiveType.JOB.value
status: JobStatus = Field(default=JobStatus.created, description="The status of the job.")
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the job was created.")

View File

@@ -19,7 +19,7 @@ from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, RE
from letta.helpers.datetime_helpers import get_utc_time, is_utc_datetime
from letta.helpers.json_helpers import json_dumps
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_VERTEX
from letta.schemas.enums import MessageRole
from letta.schemas.enums import MessageRole, PrimitiveType
from letta.schemas.letta_base import OrmMetadataBase
from letta.schemas.letta_message import (
ApprovalRequestMessage,
@@ -170,7 +170,7 @@ class MessageUpdate(BaseModel):
class BaseMessage(OrmMetadataBase):
__id_prefix__ = "message"
__id_prefix__ = PrimitiveType.MESSAGE.value
class Message(BaseMessage):

View File

@@ -4,7 +4,7 @@ from pydantic import BaseModel, Field, model_validator
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.enums import PrimitiveType, ProviderCategory, ProviderType
from letta.schemas.letta_base import LettaBase
from letta.schemas.llm_config import LLMConfig
from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES
@@ -13,7 +13,7 @@ from letta.settings import model_settings
class ProviderBase(LettaBase):
__id_prefix__ = "provider"
__id_prefix__ = PrimitiveType.PROVIDER.value
class Provider(ProviderBase):

View File

@@ -4,14 +4,14 @@ from typing import Optional
from pydantic import ConfigDict, Field
from letta.helpers.datetime_helpers import get_utc_time
from letta.schemas.enums import RunStatus
from letta.schemas.enums import PrimitiveType, RunStatus
from letta.schemas.job import LettaRequestConfig
from letta.schemas.letta_base import LettaBase
from letta.schemas.letta_stop_reason import StopReasonType
class RunBase(LettaBase):
__id_prefix__ = "run"
__id_prefix__ = PrimitiveType.RUN.value
class Run(RunBase):

View File

@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, model_validator
from letta.constants import LETTA_TOOL_EXECUTION_DIR
from letta.schemas.agent import AgentState
from letta.schemas.enums import SandboxType
from letta.schemas.enums import PrimitiveType, SandboxType
from letta.schemas.letta_base import LettaBase, OrmMetadataBase
from letta.schemas.pip_requirement import PipRequirement
from letta.services.tool_sandbox.modal_constants import DEFAULT_MODAL_TIMEOUT
@@ -92,7 +92,7 @@ class ModalSandboxConfig(BaseModel):
class SandboxConfigBase(OrmMetadataBase):
__id_prefix__ = "sandbox"
__id_prefix__ = PrimitiveType.SANDBOX_CONFIG.value
class SandboxConfig(SandboxConfigBase):

View File

@@ -5,7 +5,7 @@ from pydantic import Field
from letta.helpers.tpuf_client import should_use_tpuf
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import VectorDBProvider
from letta.schemas.enums import PrimitiveType, VectorDBProvider
from letta.schemas.letta_base import LettaBase
@@ -14,7 +14,7 @@ class BaseSource(LettaBase):
Shared attributes across all source schemas.
"""
__id_prefix__ = "source"
__id_prefix__ = PrimitiveType.SOURCE.value
# Core source fields
name: str = Field(..., description="The name of the source.")

View File

@@ -3,14 +3,14 @@ from typing import Dict, List, Literal, Optional
from pydantic import Field
from letta.schemas.enums import StepStatus
from letta.schemas.enums import PrimitiveType, StepStatus
from letta.schemas.letta_base import LettaBase
from letta.schemas.letta_stop_reason import StopReasonType
from letta.schemas.message import Message
class StepBase(LettaBase):
__id_prefix__ = "step"
__id_prefix__ = PrimitiveType.STEP.value
class Step(StepBase):

View File

@@ -11,6 +11,7 @@ from letta.constants import (
LETTA_VOICE_TOOL_MODULE_NAME,
MCP_TOOL_TAG_NAME_PREFIX,
)
from letta.schemas.enums import PrimitiveType
# MCP Tool metadata constants for schema health status
MCP_TOOL_METADATA_SCHEMA_STATUS = f"{MCP_TOOL_TAG_NAME_PREFIX}:SCHEMA_STATUS"
@@ -28,7 +29,7 @@ logger = get_logger(__name__)
class BaseTool(LettaBase):
__id_prefix__ = "tool"
__id_prefix__ = PrimitiveType.TOOL.value
class Tool(BaseTool):

View File

@@ -58,7 +58,7 @@ from letta.schemas.agent import (
)
from letta.schemas.block import DEFAULT_BLOCKS, Block as PydanticBlock, BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import AgentType, ProviderType, TagMatchMode, ToolType, VectorDBProvider
from letta.schemas.enums import AgentType, PrimitiveType, ProviderType, TagMatchMode, ToolType, VectorDBProvider
from letta.schemas.file import FileMetadata as PydanticFileMetadata
from letta.schemas.group import Group as PydanticGroup, ManagerType
from letta.schemas.llm_config import LLMConfig
@@ -110,7 +110,7 @@ from letta.services.source_manager import SourceManager
from letta.services.tool_manager import ToolManager
from letta.settings import DatabaseChoice, model_settings, settings
from letta.utils import calculate_file_defaults_based_on_context_window, enforce_types, united_diff
from letta.validators import is_valid_id
from letta.validators import raise_on_invalid_id
logger = get_logger(__name__)
@@ -672,6 +672,7 @@ class AgentManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
async def update_agent_async(
self,
agent_id: str,
@@ -976,6 +977,7 @@ class AgentManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
async def get_agent_by_id_async(
self,
agent_id: str,
@@ -984,10 +986,6 @@ class AgentManager:
) -> PydanticAgentState:
"""Fetch an agent by its ID."""
# Check if agent_id matches uuid4 format
if not is_valid_id("agent", agent_id):
raise LettaAgentNotFoundError(f"agent_id {agent_id} is not in the valid format 'agent-<uuid4>'")
async with db_registry.async_session() as session:
try:
query = select(AgentModel)
@@ -1039,6 +1037,7 @@ class AgentManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
async def get_agent_archive_ids_async(self, agent_id: str, actor: PydanticUser) -> List[str]:
"""Get all archive IDs associated with an agent."""
from letta.orm import ArchivesAgents
@@ -1052,6 +1051,7 @@ class AgentManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
async def validate_agent_exists_async(self, agent_id: str, actor: PydanticUser) -> None:
"""
Validate that an agent exists and user has access to it.
@@ -1069,6 +1069,7 @@ class AgentManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
async def delete_agent_async(self, agent_id: str, actor: PydanticUser) -> None:
"""
Deletes an agent and its associated relationships.
@@ -1131,6 +1132,7 @@ class AgentManager:
# TODO: This can also be made more efficient, instead of getting, setting, we can do it all in one db session for one query.
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
async def get_in_context_messages(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]:
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
return await self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)
@@ -1143,6 +1145,7 @@ class AgentManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
async def get_system_message_async(self, agent_id: str, actor: PydanticUser) -> PydanticMessage:
agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor)
return await self.message_manager.get_message_by_id_async(message_id=agent.message_ids[0], actor=actor)
@@ -1324,6 +1327,7 @@ class AgentManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
async def set_in_context_messages_async(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
return await self.update_agent_async(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor)
@@ -1434,6 +1438,7 @@ class AgentManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
async def update_memory_if_changed_async(self, agent_id: str, new_memory: Memory, actor: PydanticUser) -> PydanticAgentState:
"""
Update internal memory object and system prompt if there have been modifications.
@@ -1546,6 +1551,8 @@ class AgentManager:
# ======================================================================================================================
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
async def attach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
"""
Attaches a source to an agent.
@@ -1615,6 +1622,7 @@ class AgentManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
async def append_system_message_async(self, agent_id: str, content: str, actor: PydanticUser):
"""
Async version of append_system_message.
@@ -1702,6 +1710,8 @@ class AgentManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
async def detach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
"""
Detaches a source from an agent.
@@ -1787,6 +1797,8 @@ class AgentManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
async def attach_block_async(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState:
"""Attaches a block to an agent. For sleeptime agents, also attaches to paired agents in the same group."""
async with db_registry.async_session() as session:
@@ -2373,8 +2385,9 @@ class AgentManager:
# Tool Management
# ======================================================================================================================
@enforce_types
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@raise_on_invalid_id(param_name="tool_id", expected_prefix=PrimitiveType.TOOL)
async def attach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> None:
"""
Attaches a tool to an agent.
@@ -2443,6 +2456,7 @@ class AgentManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
async def bulk_attach_tools_async(self, agent_id: str, tool_ids: List[str], actor: PydanticUser) -> None:
"""
Efficiently attaches multiple tools to an agent in a single operation.
@@ -2608,6 +2622,8 @@ class AgentManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@raise_on_invalid_id(param_name="tool_id", expected_prefix=PrimitiveType.TOOL)
async def detach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> None:
"""
Detaches a tool from an agent.
@@ -2637,6 +2653,7 @@ class AgentManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
async def bulk_detach_tools_async(self, agent_id: str, tool_ids: List[str], actor: PydanticUser) -> None:
"""
Efficiently detaches multiple tools from an agent in a single operation.
@@ -2673,6 +2690,7 @@ class AgentManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
async def modify_approvals_async(self, agent_id: str, tool_name: str, requires_approval: bool, actor: PydanticUser) -> None:
def is_target_rule(rule):
return rule.tool_name == tool_name and rule.type == "requires_approval"
@@ -3021,6 +3039,7 @@ class AgentManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
async def get_agent_files_config_async(self, agent_id: str, actor: PydanticUser) -> Tuple[int, int]:
"""Get per_file_view_window_char_limit and max_files_open for an agent.
@@ -3077,6 +3096,7 @@ class AgentManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
async def get_agent_max_files_open_async(self, agent_id: str, actor: PydanticUser) -> int:
"""Get max_files_open for an agent.
@@ -3105,6 +3125,7 @@ class AgentManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
async def get_agent_per_file_view_window_char_limit_async(self, agent_id: str, actor: PydanticUser) -> int:
"""Get per_file_view_window_char_limit for an agent.
@@ -3131,7 +3152,9 @@ class AgentManager:
return row
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
async def get_context_window(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview:
agent_state, system_message, num_messages, num_archival_memories = await self.rebuild_system_prompt_async(
agent_id=agent_id, actor=actor, force=True, dry_run=True

View File

@@ -7,11 +7,12 @@ from letta.log import get_logger
from letta.orm import ArchivalPassage, Archive as ArchiveModel, ArchivesAgents
from letta.otel.tracing import trace_method
from letta.schemas.archive import Archive as PydanticArchive
from letta.schemas.enums import VectorDBProvider
from letta.schemas.enums import PrimitiveType, VectorDBProvider
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.settings import settings
from letta.utils import enforce_types
from letta.validators import raise_on_invalid_id
logger = get_logger(__name__)
@@ -47,6 +48,7 @@ class ArchiveManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
async def get_archive_by_id_async(
self,
archive_id: str,
@@ -63,6 +65,7 @@ class ArchiveManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
async def update_archive_async(
self,
archive_id: str,
@@ -89,6 +92,7 @@ class ArchiveManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
async def list_archives_async(
self,
*,
@@ -136,6 +140,8 @@ class ArchiveManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
async def attach_agent_to_archive_async(
self,
agent_id: str,
@@ -172,6 +178,7 @@ class ArchiveManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
async def get_default_archive_for_agent_async(
self,
agent_id: str,
@@ -204,6 +211,7 @@ class ArchiveManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
async def delete_archive_async(
self,
archive_id: str,
@@ -221,6 +229,7 @@ class ArchiveManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
async def get_or_create_default_archive_for_agent_async(
self,
agent_id: str,
@@ -291,6 +300,7 @@ class ArchiveManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
async def get_agents_for_archive_async(
self,
archive_id: str,
@@ -333,6 +343,7 @@ class ArchiveManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
async def get_or_set_vector_db_namespace_async(
self,
archive_id: str,

View File

@@ -15,11 +15,12 @@ from letta.orm.errors import NoResultFound
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.block import Block as PydanticBlock, BlockUpdate
from letta.schemas.enums import ActorType
from letta.schemas.enums import ActorType, PrimitiveType
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.settings import DatabaseChoice, settings
from letta.utils import enforce_types
from letta.validators import raise_on_invalid_id
logger = get_logger(__name__)
@@ -134,10 +135,9 @@ class BlockManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
async def update_block_async(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock:
"""Update a block by its ID with the given BlockUpdate object."""
# Safety check for block
async with db_registry.async_session() as session:
block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)
update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
@@ -155,6 +155,7 @@ class BlockManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
async def delete_block_async(self, block_id: str, actor: PydanticUser) -> None:
"""Delete a block by its ID."""
async with db_registry.async_session() as session:
@@ -353,6 +354,7 @@ class BlockManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
async def get_block_by_id_async(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]:
"""Retrieve a block by its name."""
async with db_registry.async_session() as session:
@@ -412,6 +414,7 @@ class BlockManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
async def get_agents_for_block_async(
self,
block_id: str,
@@ -595,6 +598,8 @@ class BlockManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
async def checkpoint_block_async(
self,
block_id: str,
@@ -703,6 +708,7 @@ class BlockManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
async def undo_checkpoint_block(
self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None
) -> PydanticBlock:
@@ -753,6 +759,7 @@ class BlockManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
async def redo_checkpoint_block(
self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None
) -> PydanticBlock:

View File

@@ -15,7 +15,7 @@ from letta.orm.errors import NoResultFound
from letta.orm.file import FileContent as FileContentModel, FileMetadata as FileMetadataModel
from letta.orm.sqlalchemy_base import AccessType
from letta.otel.tracing import trace_method
from letta.schemas.enums import FileProcessingStatus
from letta.schemas.enums import FileProcessingStatus, PrimitiveType
from letta.schemas.file import FileMetadata as PydanticFileMetadata
from letta.schemas.source import Source as PydanticSource
from letta.schemas.source_metadata import FileStats, OrganizationSourcesStats, SourceStats
@@ -23,6 +23,7 @@ from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.settings import settings
from letta.utils import enforce_types
from letta.validators import raise_on_invalid_id
logger = get_logger(__name__)
@@ -93,6 +94,7 @@ class FileManager:
# TODO: We make actor optional for now, but should most likely be enforced due to security reasons
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE)
# @async_redis_cache(
# key_func=lambda self, file_id, actor=None, include_content=False, strip_directory_prefix=False: f"{file_id}:{actor.organization_id if actor else 'none'}:{include_content}:{strip_directory_prefix}",
# prefix="file_content",
@@ -135,6 +137,7 @@ class FileManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE)
async def update_file_status(
self,
*,
@@ -171,7 +174,6 @@ class FileManager:
* 1st round-trip → UPDATE with optional state validation
* 2nd round-trip → SELECT fresh row (same as read_async) if update succeeded
"""
if processing_status is None and error_message is None and total_chunks is None and chunks_embedded is None:
raise ValueError("Nothing to update")
@@ -353,6 +355,7 @@ class FileManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE)
async def upsert_file_content(
self,
*,
@@ -398,6 +401,7 @@ class FileManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
async def list_files(
self,
source_id: str,
@@ -455,6 +459,7 @@ class FileManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE)
async def delete_file(self, file_id: str, actor: PydanticUser) -> PydanticFileMetadata:
"""Delete a file by its ID."""
async with db_registry.async_session() as session:
@@ -509,6 +514,7 @@ class FileManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
# @async_redis_cache(
# key_func=lambda self, original_filename, source_id, actor: f"{original_filename}:{source_id}:{actor.organization_id}",
# prefix="file_by_name",

View File

@@ -9,6 +9,7 @@ from letta.orm.errors import NoResultFound
from letta.orm.group import Group as GroupModel
from letta.orm.message import Message as MessageModel
from letta.otel.tracing import trace_method
from letta.schemas.enums import PrimitiveType
from letta.schemas.group import Group as PydanticGroup, GroupCreate, GroupUpdate, InternalTemplateGroupCreate, ManagerType
from letta.schemas.letta_message import LettaMessage
from letta.schemas.message import Message as PydanticMessage
@@ -16,6 +17,7 @@ from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.settings import DatabaseChoice, settings
from letta.utils import enforce_types
from letta.validators import raise_on_invalid_id
class GroupManager:
@@ -62,6 +64,7 @@ class GroupManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
async def retrieve_group_async(self, group_id: str, actor: PydanticUser) -> PydanticGroup:
async with db_registry.async_session() as session:
group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor)
@@ -119,6 +122,7 @@ class GroupManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
async def modify_group_async(self, group_id: str, group_update: GroupUpdate, actor: PydanticUser) -> PydanticGroup:
async with db_registry.async_session() as session:
group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor)
@@ -182,6 +186,7 @@ class GroupManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
async def delete_group_async(self, group_id: str, actor: PydanticUser) -> None:
async with db_registry.async_session() as session:
group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor)
@@ -189,6 +194,7 @@ class GroupManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
async def list_group_messages_async(
self,
actor: PydanticUser,
@@ -226,6 +232,7 @@ class GroupManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
async def reset_messages_async(self, group_id: str, actor: PydanticUser) -> None:
async with db_registry.async_session() as session:
# Ensure group is loadable by user
@@ -241,6 +248,7 @@ class GroupManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
async def bump_turns_counter_async(self, group_id: str, actor: PydanticUser) -> int:
async with db_registry.async_session() as session:
# Ensure group is loadable by user
@@ -253,6 +261,8 @@ class GroupManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
@raise_on_invalid_id(param_name="last_processed_message_id", expected_prefix=PrimitiveType.MESSAGE)
async def get_last_processed_message_id_and_update_async(
self, group_id: str, last_processed_message_id: str, actor: PydanticUser
) -> str:

View File

@@ -12,6 +12,7 @@ from letta.orm.identity import Identity as IdentityModel
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
from letta.schemas.block import Block
from letta.schemas.enums import PrimitiveType
from letta.schemas.identity import (
Identity as PydanticIdentity,
IdentityCreate,
@@ -24,6 +25,7 @@ from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.settings import DatabaseChoice, settings
from letta.utils import enforce_types
from letta.validators import raise_on_invalid_id
class IdentityManager:
@@ -62,6 +64,7 @@ class IdentityManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
async def get_identity_async(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity:
async with db_registry.async_session() as session:
identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
@@ -143,6 +146,7 @@ class IdentityManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
async def update_identity_async(
self, identity_id: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False
) -> PydanticIdentity:
@@ -206,6 +210,7 @@ class IdentityManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
async def upsert_identity_properties_async(
self, identity_id: str, properties: List[IdentityProperty], actor: PydanticUser
) -> PydanticIdentity:
@@ -223,6 +228,7 @@ class IdentityManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
async def delete_identity_async(self, identity_id: str, actor: PydanticUser) -> None:
async with db_registry.async_session() as session:
identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
@@ -280,6 +286,7 @@ class IdentityManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
async def list_agents_for_identity_async(
self,
identity_id: str,
@@ -311,6 +318,7 @@ class IdentityManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
async def list_blocks_for_identity_async(
self,
identity_id: str,

View File

@@ -14,7 +14,7 @@ from letta.orm.message import Message as MessageModel
from letta.orm.sqlalchemy_base import AccessType
from letta.orm.step import Step, Step as StepModel
from letta.otel.tracing import log_event, trace_method
from letta.schemas.enums import JobStatus, JobType, MessageRole
from letta.schemas.enums import JobStatus, JobType, MessageRole, PrimitiveType
from letta.schemas.job import BatchJob as PydanticBatchJob, Job as PydanticJob, JobUpdate, LettaRequestConfig
from letta.schemas.letta_message import LettaMessage
from letta.schemas.letta_stop_reason import StopReasonType
@@ -26,6 +26,7 @@ from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.services.helpers.agent_manager_helper import validate_agent_exists_async
from letta.utils import enforce_types
from letta.validators import raise_on_invalid_id
logger = get_logger(__name__)
@@ -70,6 +71,7 @@ class JobManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
async def update_job_by_id_async(
self, job_id: str, job_update: JobUpdate, actor: PydanticUser, safe_update: bool = False
) -> PydanticJob:
@@ -147,6 +149,7 @@ class JobManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
async def safe_update_job_status_async(
self,
job_id: str,
@@ -187,6 +190,7 @@ class JobManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
async def get_job_by_id_async(self, job_id: str, actor: PydanticUser) -> PydanticJob:
"""Fetch a job by its ID asynchronously."""
async with db_registry.async_session() as session:
@@ -301,6 +305,7 @@ class JobManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
async def delete_job_by_id_async(self, job_id: str, actor: PydanticUser) -> PydanticJob:
"""Delete a job by its ID."""
async with db_registry.async_session() as session:
@@ -310,6 +315,7 @@ class JobManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
async def get_run_messages(
self,
run_id: str,
@@ -367,6 +373,7 @@ class JobManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
async def get_step_messages(
self,
run_id: str,
@@ -447,6 +454,7 @@ class JobManager:
return job
@enforce_types
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
async def record_ttft(self, job_id: str, ttft_ns: int, actor: PydanticUser) -> None:
"""Record time to first token for a run"""
try:
@@ -459,6 +467,7 @@ class JobManager:
logger.warning(f"Failed to record TTFT for job {job_id}: {e}")
@enforce_types
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
async def record_response_duration(self, job_id: str, total_duration_ns: int, actor: PydanticUser) -> None:
"""Record total response duration for a run"""
try:
@@ -529,6 +538,7 @@ class JobManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
async def get_job_steps(
self,
job_id: str,

View File

@@ -25,6 +25,7 @@ from letta.orm.errors import NoResultFound
from letta.orm.mcp_oauth import MCPOAuth, OAuthSessionStatus
from letta.orm.mcp_server import MCPServer as MCPServerModel
from letta.orm.tool import Tool as ToolModel
from letta.schemas.enums import PrimitiveType
from letta.schemas.mcp import (
MCPOAuthSession,
MCPOAuthSessionCreate,
@@ -47,6 +48,7 @@ from letta.services.mcp.streamable_http_client import AsyncStreamableHTTPMCPClie
from letta.services.tool_manager import ToolManager
from letta.settings import settings, tool_settings
from letta.utils import enforce_types, printd, safe_create_task_with_return
from letta.validators import raise_on_invalid_id
logger = get_logger(__name__)
@@ -60,6 +62,7 @@ class MCPManager:
self.cached_mcp_servers = {} # maps id -> async connection
@enforce_types
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
async def list_mcp_server_tools(self, mcp_server_name: str, actor: PydanticUser, agent_id: Optional[str] = None) -> List[MCPTool]:
"""Get a list of all tools for a specific MCP server."""
mcp_client = None

View File

@@ -10,7 +10,7 @@ from letta.log import get_logger
from letta.orm.errors import NoResultFound
from letta.orm.message import Message as MessageModel
from letta.otel.tracing import trace_method
from letta.schemas.enums import MessageRole
from letta.schemas.enums import MessageRole, PrimitiveType
from letta.schemas.letta_message import LettaMessageUpdateUnion
from letta.schemas.letta_message_content import ImageSourceType, LettaImage, MessageContentType, TextContent
from letta.schemas.message import Message as PydanticMessage, MessageSearchResult, MessageUpdate
@@ -20,6 +20,7 @@ from letta.services.file_manager import FileManager
from letta.services.helpers.agent_manager_helper import validate_agent_exists_async
from letta.settings import DatabaseChoice, settings
from letta.utils import enforce_types, fire_and_forget
from letta.validators import raise_on_invalid_id
logger = get_logger(__name__)
@@ -308,6 +309,7 @@ class MessageManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="message_id", expected_prefix=PrimitiveType.MESSAGE)
async def get_message_by_id_async(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]:
"""Fetch a message by ID."""
async with db_registry.async_session() as session:
@@ -712,6 +714,7 @@ class MessageManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="message_id", expected_prefix=PrimitiveType.MESSAGE)
async def delete_message_by_id_async(self, message_id: str, actor: PydanticUser, strict_mode: bool = False) -> bool:
"""Delete a message (async version with turbopuffer support)."""
# capture agent_id before deletion

View File

@@ -2,12 +2,13 @@ from typing import List, Optional, Tuple, Union
from letta.orm.provider import Provider as ProviderModel
from letta.otel.tracing import trace_method
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.enums import PrimitiveType, ProviderCategory, ProviderType
from letta.schemas.providers import Provider as PydanticProvider, ProviderCheck, ProviderCreate, ProviderUpdate
from letta.schemas.secret import Secret
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types
from letta.validators import raise_on_invalid_id
class ProviderManager:
@@ -40,6 +41,7 @@ class ProviderManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
async def update_provider_async(self, provider_id: str, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider:
"""Update provider details."""
async with db_registry.async_session() as session:
@@ -103,6 +105,7 @@ class ProviderManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
async def delete_provider_by_id_async(self, provider_id: str, actor: PydanticUser):
"""Delete a provider."""
async with db_registry.async_session() as session:
@@ -153,6 +156,7 @@ class ProviderManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
async def get_provider_async(self, provider_id: str, actor: PydanticUser) -> PydanticProvider:
async with db_registry.async_session() as session:
provider_model = await ProviderModel.read_async(db_session=session, identifier=provider_id, actor=actor)

View File

@@ -14,7 +14,7 @@ from letta.orm.run_metrics import RunMetrics as RunMetricsModel
from letta.orm.sqlalchemy_base import AccessType
from letta.orm.step import Step as StepModel
from letta.otel.tracing import log_event, trace_method
from letta.schemas.enums import AgentType, ComparisonOperator, MessageRole, RunStatus
from letta.schemas.enums import AgentType, ComparisonOperator, MessageRole, RunStatus, PrimitiveType
from letta.schemas.job import LettaRequestConfig
from letta.schemas.letta_message import LettaMessage, LettaMessageUnion
from letta.schemas.letta_response import LettaResponse
@@ -31,6 +31,7 @@ from letta.services.helpers.agent_manager_helper import validate_agent_exists_as
from letta.services.message_manager import MessageManager
from letta.services.step_manager import StepManager
from letta.utils import enforce_types
from letta.validators import raise_on_invalid_id
logger = get_logger(__name__)
@@ -85,6 +86,7 @@ class RunManager:
return run.to_pydantic()
@enforce_types
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
async def get_run_by_id(self, run_id: str, actor: PydanticUser) -> PydanticRun:
"""Get a run by its ID."""
async with db_registry.async_session() as session:
@@ -176,6 +178,7 @@ class RunManager:
return [run.to_pydantic() for run in runs]
@enforce_types
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
async def delete_run(self, run_id: str, actor: PydanticUser) -> PydanticRun:
"""Delete a run by its ID."""
async with db_registry.async_session() as session:
@@ -189,11 +192,11 @@ class RunManager:
return pydantic_run
@enforce_types
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
async def update_run_by_id_async(
self, run_id: str, update: RunUpdate, actor: PydanticUser, refresh_result_messages: bool = True
) -> PydanticRun:
"""Update a run using a RunUpdate object."""
async with db_registry.async_session() as session:
run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor)
@@ -327,6 +330,7 @@ class RunManager:
return result
@enforce_types
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
async def get_run_usage(self, run_id: str, actor: PydanticUser) -> LettaUsageStatistics:
"""Get usage statistics for a run."""
async with db_registry.async_session() as session:
@@ -344,6 +348,7 @@ class RunManager:
return total_usage
@enforce_types
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
async def get_run_messages(
self,
run_id: str,
@@ -378,6 +383,7 @@ class RunManager:
return letta_messages
@enforce_types
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
async def get_run_request_config(self, run_id: str, actor: PydanticUser) -> Optional[LettaRequestConfig]:
"""Get the letta request config from a run."""
async with db_registry.async_session() as session:
@@ -388,6 +394,7 @@ class RunManager:
return pydantic_run.request_config
@enforce_types
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
async def get_run_metrics_async(self, run_id: str, actor: PydanticUser) -> PydanticRunMetrics:
"""Get metrics for a run."""
async with db_registry.async_session() as session:
@@ -395,6 +402,7 @@ class RunManager:
return metrics.to_pydantic()
@enforce_types
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
async def get_run_steps(
self,
run_id: str,

View File

@@ -5,7 +5,7 @@ from letta.log import get_logger
from letta.orm.errors import NoResultFound
from letta.orm.sandbox_config import SandboxConfig as SandboxConfigModel, SandboxEnvironmentVariable as SandboxEnvVarModel
from letta.otel.tracing import trace_method
from letta.schemas.enums import SandboxType
from letta.schemas.enums import PrimitiveType, SandboxType
from letta.schemas.environment_variables import (
SandboxEnvironmentVariable as PydanticEnvVar,
SandboxEnvironmentVariableCreate,
@@ -20,6 +20,7 @@ from letta.schemas.sandbox_config import (
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types, printd
from letta.validators import raise_on_invalid_id
logger = get_logger(__name__)
@@ -101,6 +102,7 @@ class SandboxConfigManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
async def update_sandbox_config_async(
self, sandbox_config_id: str, sandbox_update: SandboxConfigUpdate, actor: PydanticUser
) -> PydanticSandboxConfig:
@@ -129,6 +131,7 @@ class SandboxConfigManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
async def delete_sandbox_config_async(self, sandbox_config_id: str, actor: PydanticUser) -> PydanticSandboxConfig:
"""Delete a sandbox configuration by its ID."""
async with db_registry.async_session() as session:
@@ -176,6 +179,7 @@ class SandboxConfigManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
async def create_sandbox_env_var_async(
self, env_var_create: SandboxEnvironmentVariableCreate, sandbox_config_id: str, actor: PydanticUser
) -> PydanticEnvVar:
@@ -266,6 +270,7 @@ class SandboxConfigManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
async def list_sandbox_env_vars_async(
self,
sandbox_config_id: str,
@@ -302,6 +307,7 @@ class SandboxConfigManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
def get_sandbox_env_vars_as_dict(
self, sandbox_config_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50
) -> Dict[str, str]:
@@ -315,6 +321,7 @@ class SandboxConfigManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
async def get_sandbox_env_vars_as_dict_async(
self, sandbox_config_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50
) -> Dict[str, str]:
@@ -324,6 +331,7 @@ class SandboxConfigManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
async def get_sandbox_env_var_by_key_and_sandbox_config_id_async(
self, key: str, sandbox_config_id: str, actor: Optional[PydanticUser] = None
) -> Optional[PydanticEnvVar]:

View File

@@ -11,11 +11,12 @@ from letta.orm.source import Source as SourceModel
from letta.orm.sources_agents import SourcesAgents
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.enums import VectorDBProvider
from letta.schemas.enums import PrimitiveType, VectorDBProvider
from letta.schemas.source import Source as PydanticSource, SourceUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types, printd
from letta.validators import raise_on_invalid_id
class SourceManager:
@@ -201,6 +202,7 @@ class SourceManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
async def update_source(self, source_id: str, source_update: SourceUpdate, actor: PydanticUser) -> PydanticSource:
"""Update a source by its ID with the given SourceUpdate object."""
async with db_registry.async_session() as session:
@@ -224,6 +226,7 @@ class SourceManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
async def delete_source(self, source_id: str, actor: PydanticUser) -> PydanticSource:
"""Delete a source by its ID."""
async with db_registry.async_session() as session:
@@ -268,6 +271,7 @@ class SourceManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
async def list_attached_agents(
self, source_id: str, actor: PydanticUser, ids_only: bool = False
) -> Union[List[PydanticAgentState], List[str]]:
@@ -321,6 +325,7 @@ class SourceManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
async def get_agents_for_source_id(self, source_id: str, actor: PydanticUser) -> List[str]:
"""
Get all agent IDs associated with a given source ID.
@@ -347,6 +352,7 @@ class SourceManager:
# TODO: We make actor optional for now, but should most likely be enforced due to security reasons
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
async def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]:
"""Retrieve a source by its ID."""
async with db_registry.async_session() as session:

View File

@@ -13,7 +13,7 @@ from letta.orm.sqlalchemy_base import AccessType
from letta.orm.step import Step as StepModel
from letta.orm.step_metrics import StepMetrics as StepMetricsModel
from letta.otel.tracing import get_trace_id, trace_method
from letta.schemas.enums import StepStatus
from letta.schemas.enums import PrimitiveType, StepStatus
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_response import UsageStatistics
@@ -22,6 +22,7 @@ from letta.schemas.step_metrics import StepMetrics as PydanticStepMetrics
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types
from letta.validators import raise_on_invalid_id
class FeedbackType(str, Enum):
@@ -32,6 +33,8 @@ class FeedbackType(str, Enum):
class StepManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
async def list_steps_async(
self,
actor: PydanticUser,
@@ -79,6 +82,10 @@ class StepManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
def log_step(
self,
actor: PydanticUser,
@@ -133,6 +140,10 @@ class StepManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
async def log_step_async(
self,
actor: PydanticUser,
@@ -196,6 +207,7 @@ class StepManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
async def get_step_async(self, step_id: str, actor: PydanticUser) -> PydanticStep:
async with db_registry.async_session() as session:
step = await StepModel.read_async(db_session=session, identifier=step_id, actor=actor)
@@ -203,6 +215,7 @@ class StepManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
async def get_step_metrics_async(self, step_id: str, actor: PydanticUser) -> PydanticStepMetrics:
async with db_registry.async_session() as session:
metrics = await StepMetricsModel.read_async(db_session=session, identifier=step_id, actor=actor)
@@ -210,6 +223,7 @@ class StepManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
async def add_feedback_async(
self, step_id: str, feedback: FeedbackType | None, actor: PydanticUser, tags: list[str] | None = None
) -> PydanticStep:
@@ -225,6 +239,7 @@ class StepManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
async def update_step_transaction_id(self, actor: PydanticUser, step_id: str, transaction_id: str) -> PydanticStep:
"""Update the transaction ID for a step.
@@ -252,6 +267,7 @@ class StepManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
async def list_step_messages_async(
self,
step_id: str,
@@ -276,6 +292,7 @@ class StepManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
async def update_step_stop_reason(self, actor: PydanticUser, step_id: str, stop_reason: StopReasonType) -> PydanticStep:
"""Update the stop reason for a step.
@@ -303,6 +320,7 @@ class StepManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
async def update_step_error_async(
self,
actor: PydanticUser,
@@ -348,6 +366,7 @@ class StepManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
async def update_step_success_async(
self,
actor: PydanticUser,
@@ -388,6 +407,7 @@ class StepManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
async def update_step_cancelled_async(
self,
actor: PydanticUser,
@@ -423,6 +443,9 @@ class StepManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
async def record_step_metrics_async(
self,
actor: PydanticUser,

View File

@@ -27,7 +27,7 @@ from letta.log import get_logger
from letta.orm.errors import NoResultFound
from letta.orm.tool import Tool as ToolModel
from letta.otel.tracing import trace_method
from letta.schemas.enums import ToolType
from letta.schemas.enums import PrimitiveType, ToolType
from letta.schemas.tool import Tool as PydanticTool, ToolCreate, ToolUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
@@ -36,6 +36,7 @@ from letta.services.mcp.types import SSEServerConfig, StdioServerConfig
from letta.services.tool_schema_generator import generate_schema_for_tool_creation, generate_schema_for_tool_update
from letta.settings import settings
from letta.utils import enforce_types, printd
from letta.validators import raise_on_invalid_id
logger = get_logger(__name__)
@@ -203,6 +204,7 @@ class ToolManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="tool_id", expected_prefix=PrimitiveType.TOOL)
async def get_tool_by_id_async(self, tool_id: str, actor: PydanticUser) -> PydanticTool:
"""Fetch a tool by its ID."""
async with db_registry.async_session() as session:
@@ -235,6 +237,7 @@ class ToolManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="tool_id", expected_prefix=PrimitiveType.TOOL)
async def tool_exists_async(self, tool_id: str, actor: PydanticUser) -> bool:
"""Check if a tool exists and belongs to the user's organization (lightweight check)."""
async with db_registry.async_session() as session:
@@ -507,6 +510,7 @@ class ToolManager:
@enforce_types
@trace_method
@raise_on_invalid_id(param_name="tool_id", expected_prefix=PrimitiveType.TOOL)
async def update_tool_by_id_async(
self,
tool_id: str,
@@ -604,6 +608,7 @@ class ToolManager:
@enforce_types
@trace_method
# @raise_on_invalid_id This is commented out bc it's called by _list_tools_async, when it encounters malformed tools (i.e. if id is invalid will fail validation on deletion)
async def delete_tool_by_id_async(self, tool_id: str, actor: PydanticUser) -> None:
"""Delete a tool by its ID."""
async with db_registry.async_session() as session:

View File

@@ -1,48 +1,21 @@
import inspect
import re
from functools import wraps
from typing import Annotated
from fastapi import Path
from letta.schemas.agent import AgentState
from letta.schemas.archive import ArchiveBase
from letta.schemas.block import BaseBlock
from letta.schemas.file import FileMetadataBase
from letta.schemas.folder import BaseFolder
from letta.schemas.group import GroupBase
from letta.schemas.identity import IdentityBase
from letta.schemas.job import JobBase
from letta.schemas.message import BaseMessage
from letta.schemas.providers import ProviderBase
from letta.schemas.run import RunBase
from letta.schemas.sandbox_config import SandboxConfigBase
from letta.schemas.source import BaseSource
from letta.schemas.step import StepBase
from letta.schemas.tool import BaseTool
from letta.errors import LettaInvalidArgumentError
from letta.schemas.enums import PrimitiveType # PrimitiveType is now in schemas.enums
# TODO: extract this list from routers/v1/__init__.py and ROUTERS
primitives = [
AgentState.__id_prefix__,
BaseMessage.__id_prefix__,
RunBase.__id_prefix__,
JobBase.__id_prefix__,
GroupBase.__id_prefix__,
BaseBlock.__id_prefix__,
FileMetadataBase.__id_prefix__,
BaseFolder.__id_prefix__,
BaseSource.__id_prefix__,
BaseTool.__id_prefix__,
ArchiveBase.__id_prefix__,
ProviderBase.__id_prefix__,
SandboxConfigBase.__id_prefix__,
StepBase.__id_prefix__,
IdentityBase.__id_prefix__,
]
# Map from PrimitiveType to the actual prefix string (which is just the enum value)
PRIMITIVE_ID_PREFIXES = {primitive_type: primitive_type.value for primitive_type in PrimitiveType}
PRIMITIVE_ID_PATTERNS = {
# f-string interpolation gets confused because of the regex's required curly braces {}
primitive: re.compile("^" + primitive + "-[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$")
for primitive in primitives
prefix: re.compile("^" + prefix + "-[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$")
for prefix in PRIMITIVE_ID_PREFIXES.values()
}
@@ -67,28 +40,65 @@ def _create_path_validator_factory(primitive: str):
# PATH_VALIDATORS now contains factory functions, not Path objects
# Usage: folder_id: str = PATH_VALIDATORS[BaseFolder.__id_prefix__]()
PATH_VALIDATORS = {primitive: _create_path_validator_factory(primitive) for primitive in primitives}
def is_valid_id(primitive: str, id: str) -> bool:
return PRIMITIVE_ID_PATTERNS[primitive].match(id) is not None
# Usage: folder_id: str = PATH_VALIDATORS[PrimitiveType.FOLDER.value]()
PATH_VALIDATORS = {primitive_type.value: _create_path_validator_factory(primitive_type.value) for primitive_type in PrimitiveType}
# Type aliases for common ID types
# These can be used directly in route handler signatures for cleaner code
AgentId = Annotated[str, PATH_VALIDATORS[AgentState.__id_prefix__]()]
ToolId = Annotated[str, PATH_VALIDATORS[BaseTool.__id_prefix__]()]
SourceId = Annotated[str, PATH_VALIDATORS[BaseSource.__id_prefix__]()]
BlockId = Annotated[str, PATH_VALIDATORS[BaseBlock.__id_prefix__]()]
MessageId = Annotated[str, PATH_VALIDATORS[BaseMessage.__id_prefix__]()]
RunId = Annotated[str, PATH_VALIDATORS[RunBase.__id_prefix__]()]
JobId = Annotated[str, PATH_VALIDATORS[JobBase.__id_prefix__]()]
GroupId = Annotated[str, PATH_VALIDATORS[GroupBase.__id_prefix__]()]
FileId = Annotated[str, PATH_VALIDATORS[FileMetadataBase.__id_prefix__]()]
FolderId = Annotated[str, PATH_VALIDATORS[BaseFolder.__id_prefix__]()]
ArchiveId = Annotated[str, PATH_VALIDATORS[ArchiveBase.__id_prefix__]()]
ProviderId = Annotated[str, PATH_VALIDATORS[ProviderBase.__id_prefix__]()]
SandboxConfigId = Annotated[str, PATH_VALIDATORS[SandboxConfigBase.__id_prefix__]()]
StepId = Annotated[str, PATH_VALIDATORS[StepBase.__id_prefix__]()]
IdentityId = Annotated[str, PATH_VALIDATORS[IdentityBase.__id_prefix__]()]
AgentId = Annotated[str, PATH_VALIDATORS[PrimitiveType.AGENT.value]()]
ToolId = Annotated[str, PATH_VALIDATORS[PrimitiveType.TOOL.value]()]
SourceId = Annotated[str, PATH_VALIDATORS[PrimitiveType.SOURCE.value]()]
BlockId = Annotated[str, PATH_VALIDATORS[PrimitiveType.BLOCK.value]()]
MessageId = Annotated[str, PATH_VALIDATORS[PrimitiveType.MESSAGE.value]()]
RunId = Annotated[str, PATH_VALIDATORS[PrimitiveType.RUN.value]()]
JobId = Annotated[str, PATH_VALIDATORS[PrimitiveType.JOB.value]()]
GroupId = Annotated[str, PATH_VALIDATORS[PrimitiveType.GROUP.value]()]
FileId = Annotated[str, PATH_VALIDATORS[PrimitiveType.FILE.value]()]
FolderId = Annotated[str, PATH_VALIDATORS[PrimitiveType.FOLDER.value]()]
ArchiveId = Annotated[str, PATH_VALIDATORS[PrimitiveType.ARCHIVE.value]()]
ProviderId = Annotated[str, PATH_VALIDATORS[PrimitiveType.PROVIDER.value]()]
SandboxConfigId = Annotated[str, PATH_VALIDATORS[PrimitiveType.SANDBOX_CONFIG.value]()]
StepId = Annotated[str, PATH_VALIDATORS[PrimitiveType.STEP.value]()]
IdentityId = Annotated[str, PATH_VALIDATORS[PrimitiveType.IDENTITY.value]()]
def raise_on_invalid_id(param_name: str, expected_prefix: PrimitiveType):
"""
Decorator that validates an ID parameter has the expected prefix format.
Can be stacked multiple times on the same function to validate different IDs.
Args:
param_name: The name of the function parameter to validate (e.g., "agent_id")
expected_prefix: The expected primitive type (e.g., PrimitiveType.AGENT)
Example:
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
@raise_on_invalid_id(param_name="folder_id", expected_prefix=PrimitiveType.FOLDER)
def my_function(agent_id: str, folder_id: str):
pass
"""
def decorator(function):
@wraps(function)
def wrapper(*args, **kwargs):
sig = inspect.signature(function)
bound_args = sig.bind(*args, **kwargs)
bound_args.apply_defaults()
if param_name in bound_args.arguments:
arg_value = bound_args.arguments[param_name]
if arg_value is not None:
prefix = PRIMITIVE_ID_PREFIXES[expected_prefix]
if PRIMITIVE_ID_PATTERNS[prefix].match(arg_value) is None:
raise LettaInvalidArgumentError(
message=f"Invalid {expected_prefix.value} ID format: {arg_value}. Expected format: '{prefix}-<uuid4>'",
argument_name=param_name,
)
return function(*args, **kwargs)
return wrapper
return decorator

View File

@@ -286,7 +286,7 @@ async def test_turbopuffer_metadata_attributes(default_user, enable_turbopuffer)
pytest.skip("No Turbopuffer API key available")
client = TurbopufferClient()
archive_id = f"test-archive-{datetime.now().timestamp()}"
archive_id = f"archive-{uuid.uuid4()}"
try:
# Insert passages with various metadata
@@ -391,7 +391,7 @@ async def test_hybrid_search_with_real_tpuf(default_user, enable_turbopuffer):
from letta.helpers.tpuf_client import TurbopufferClient
client = TurbopufferClient()
archive_id = f"test-hybrid-{datetime.now().timestamp()}"
archive_id = f"archive-{uuid.uuid4()}"
org_id = str(uuid.uuid4())
try:
@@ -497,7 +497,7 @@ async def test_tag_filtering_with_real_tpuf(default_user, enable_turbopuffer):
from letta.helpers.tpuf_client import TurbopufferClient
client = TurbopufferClient()
archive_id = f"test-tags-{datetime.now().timestamp()}"
archive_id = f"archive-{uuid.uuid4()}"
org_id = str(uuid.uuid4())
try:
@@ -628,7 +628,7 @@ async def test_temporal_filtering_with_real_tpuf(default_user, enable_turbopuffe
client = TurbopufferClient()
# Create a unique archive ID for this test
archive_id = f"test-temporal-{uuid.uuid4()}"
archive_id = f"archive-{uuid.uuid4()}"
try:
# Create passages with different timestamps

View File

@@ -44,7 +44,7 @@ from letta.constants import (
MULTI_AGENT_TOOLS,
)
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
from letta.errors import LettaAgentNotFoundError
from letta.errors import LettaAgentNotFoundError, LettaInvalidArgumentError
from letta.functions.functions import derive_openai_json_schema, parse_source_code
from letta.functions.mcp_client.types import MCPTool
from letta.helpers import ToolRulesSolver
@@ -233,7 +233,7 @@ async def test_update_job_auto_complete(server: SyncServer, default_user):
async def test_get_job_not_found(server: SyncServer, default_user):
"""Test fetching a non-existent job."""
non_existent_job_id = "nonexistent-id"
with pytest.raises(NoResultFound):
with pytest.raises(LettaInvalidArgumentError):
await server.job_manager.get_job_by_id_async(non_existent_job_id, actor=default_user)
@@ -241,7 +241,7 @@ async def test_get_job_not_found(server: SyncServer, default_user):
async def test_delete_job_not_found(server: SyncServer, default_user):
"""Test deleting a non-existent job."""
non_existent_job_id = "nonexistent-id"
with pytest.raises(NoResultFound):
with pytest.raises(LettaInvalidArgumentError):
await server.job_manager.delete_job_by_id_async(non_existent_job_id, actor=default_user)
@@ -412,9 +412,6 @@ async def test_e2e_job_callback(monkeypatch, server: SyncServer, default_user):
# ======================================================================================================================
@pytest.mark.asyncio
async def test_record_ttft(server: SyncServer, default_user):
"""Test recording time to first token for a job."""
@@ -478,9 +475,11 @@ async def test_record_timing_metrics_together(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_record_timing_invalid_job(server: SyncServer, default_user):
"""Test recording timing metrics for non-existent job fails gracefully."""
# Try to record TTFT for non-existent job - should not raise exception but log warning
await server.job_manager.record_ttft("nonexistent_job_id", 1_000_000_000, default_user)
"""Test recording timing metrics for non-existent job raises LettaInvalidArgumentError."""
# Try to record TTFT for non-existent job - should raise LettaInvalidArgumentError
with pytest.raises(LettaInvalidArgumentError):
await server.job_manager.record_ttft("nonexistent_job_id", 1_000_000_000, default_user)
# Try to record response duration for non-existent job - should not raise exception but log warning
await server.job_manager.record_response_duration("nonexistent_job_id", 2_000_000_000, default_user)
# Try to record response duration for non-existent job - should raise LettaInvalidArgumentError
with pytest.raises(LettaInvalidArgumentError):
await server.job_manager.record_response_duration("nonexistent_job_id", 2_000_000_000, default_user)

View File

@@ -44,7 +44,7 @@ from letta.constants import (
MULTI_AGENT_TOOLS,
)
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
from letta.errors import LettaAgentNotFoundError
from letta.errors import LettaAgentNotFoundError, LettaInvalidArgumentError
from letta.functions.functions import derive_openai_json_schema, parse_source_code
from letta.functions.mcp_client.types import MCPTool
from letta.helpers import ToolRulesSolver
@@ -241,7 +241,7 @@ async def test_update_run_auto_complete(server: SyncServer, default_user, sarah_
async def test_get_run_not_found(server: SyncServer, default_user):
"""Test fetching a non-existent run."""
non_existent_run_id = "nonexistent-id"
with pytest.raises(NoResultFound):
with pytest.raises(LettaInvalidArgumentError):
await server.run_manager.get_run_by_id(non_existent_run_id, actor=default_user)
@@ -249,7 +249,7 @@ async def test_get_run_not_found(server: SyncServer, default_user):
async def test_delete_run_not_found(server: SyncServer, default_user):
"""Test deleting a non-existent run."""
non_existent_run_id = "nonexistent-id"
with pytest.raises(NoResultFound):
with pytest.raises(LettaInvalidArgumentError):
await server.run_manager.delete_run(non_existent_run_id, actor=default_user)
@@ -1268,7 +1268,7 @@ async def test_run_usage_stats_get_nonexistent_run(server: SyncServer, default_u
"""Test getting usage statistics for a nonexistent run."""
run_manager = server.run_manager
with pytest.raises(NoResultFound):
with pytest.raises(LettaInvalidArgumentError):
await run_manager.get_run_usage(run_id="nonexistent_run", actor=default_user)
@@ -1307,7 +1307,7 @@ async def test_get_run_request_config_none(server: SyncServer, sarah_agent, defa
@pytest.mark.asyncio
async def test_get_run_request_config_nonexistent_run(server: SyncServer, default_user):
"""Test getting request config for a nonexistent run."""
with pytest.raises(NoResultFound):
with pytest.raises(LettaInvalidArgumentError):
await server.run_manager.get_run_request_config("nonexistent_run", actor=default_user)
@@ -1453,7 +1453,7 @@ async def test_run_metrics_num_steps_tracking(server: SyncServer, sarah_agent, d
@pytest.mark.asyncio
async def test_run_metrics_not_found(server: SyncServer, default_user):
"""Test getting metrics for non-existent run."""
with pytest.raises(NoResultFound):
with pytest.raises(LettaInvalidArgumentError):
await server.run_manager.get_run_metrics_async(run_id="nonexistent_run", actor=default_user)

View File

@@ -44,7 +44,7 @@ from letta.constants import (
MULTI_AGENT_TOOLS,
)
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
from letta.errors import LettaAgentNotFoundError
from letta.errors import LettaAgentNotFoundError, LettaInvalidArgumentError
from letta.functions.functions import derive_openai_json_schema, parse_source_code
from letta.functions.mcp_client.types import MCPTool
from letta.helpers import ToolRulesSolver
@@ -174,21 +174,21 @@ async def test_detach_source(server: SyncServer, sarah_agent, default_source, de
async def test_attach_source_nonexistent_agent(server: SyncServer, default_source, default_user):
"""Test attaching a source to a nonexistent agent."""
with pytest.raises(NoResultFound):
await server.agent_manager.attach_source_async(agent_id="nonexistent-agent-id", source_id=default_source.id, actor=default_user)
await server.agent_manager.attach_source_async(agent_id=f"agent-{uuid.uuid4()}", source_id=default_source.id, actor=default_user)
@pytest.mark.asyncio
async def test_attach_source_nonexistent_source(server: SyncServer, sarah_agent, default_user):
"""Test attaching a nonexistent source to an agent."""
with pytest.raises(NoResultFound):
await server.agent_manager.attach_source_async(agent_id=sarah_agent.id, source_id="nonexistent-source-id", actor=default_user)
await server.agent_manager.attach_source_async(agent_id=sarah_agent.id, source_id=f"source-{uuid.uuid4()}", actor=default_user)
@pytest.mark.asyncio
async def test_detach_source_nonexistent_agent(server: SyncServer, default_source, default_user):
"""Test detaching a source from a nonexistent agent."""
with pytest.raises(LettaAgentNotFoundError):
await server.agent_manager.detach_source_async(agent_id="nonexistent-agent-id", source_id=default_source.id, actor=default_user)
await server.agent_manager.detach_source_async(agent_id=f"agent-{uuid.uuid4()}", source_id=default_source.id, actor=default_user)
@pytest.mark.asyncio
@@ -234,7 +234,7 @@ async def test_list_attached_agents(server: SyncServer, sarah_agent, charles_age
async def test_list_attached_agents_nonexistent_source(server: SyncServer, default_user):
"""Test listing agents for a nonexistent source."""
with pytest.raises(NoResultFound):
with pytest.raises(LettaInvalidArgumentError):
await server.source_manager.list_attached_agents(source_id="nonexistent-source-id", actor=default_user)

View File

@@ -220,7 +220,7 @@ async def test_bulk_detach_tools_idempotent(server: SyncServer, sarah_agent, pri
@pytest.mark.asyncio
async def test_bulk_detach_tools_nonexistent_agent(server: SyncServer, print_tool, other_tool, default_user):
"""Test bulk detaching tools from a nonexistent agent."""
nonexistent_agent_id = "nonexistent-agent-id"
nonexistent_agent_id = f"agent-{uuid.uuid4()}"
tool_ids = [print_tool.id, other_tool.id]
with pytest.raises(LettaAgentNotFoundError):
@@ -230,19 +230,19 @@ async def test_bulk_detach_tools_nonexistent_agent(server: SyncServer, print_too
async def test_attach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user):
"""Test attaching a tool to a nonexistent agent."""
with pytest.raises(LettaAgentNotFoundError):
await server.agent_manager.attach_tool_async(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user)
await server.agent_manager.attach_tool_async(agent_id=f"agent-{uuid.uuid4()}", tool_id=print_tool.id, actor=default_user)
async def test_attach_tool_nonexistent_tool(server: SyncServer, sarah_agent, default_user):
"""Test attaching a nonexistent tool to an agent."""
with pytest.raises(NoResultFound):
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id="nonexistent-tool-id", actor=default_user)
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=f"tool-{uuid.uuid4()}", actor=default_user)
async def test_detach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user):
"""Test detaching a tool from a nonexistent agent."""
with pytest.raises(LettaAgentNotFoundError):
await server.agent_manager.detach_tool_async(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user)
await server.agent_manager.detach_tool_async(agent_id=f"agent-{uuid.uuid4()}", tool_id=print_tool.id, actor=default_user)
@pytest.mark.asyncio
@@ -330,7 +330,7 @@ async def test_bulk_attach_tools_nonexistent_tool(server: SyncServer, sarah_agen
@pytest.mark.asyncio
async def test_bulk_attach_tools_nonexistent_agent(server: SyncServer, print_tool, other_tool, default_user):
"""Test bulk attaching tools to a nonexistent agent."""
nonexistent_agent_id = "nonexistent-agent-id"
nonexistent_agent_id = f"agent-{uuid.uuid4()}"
tool_ids = [print_tool.id, other_tool.id]
with pytest.raises(LettaAgentNotFoundError):
@@ -2187,7 +2187,8 @@ async def test_list_tools_with_corrupted_tool(server: SyncServer, default_user,
from letta.orm.tool import Tool as ToolModel
async with db_registry.async_session() as session:
# Create a tool with no json_schema (corrupted state)
# Create a tool with corrupted ID format (bypassing validation)
# This simulates a tool that somehow got corrupted in the database
corrupted_tool = ToolModel(
id=f"tool-corrupted-{uuid.uuid4()}",
name="corrupted_tool",