feat: runtime validation for ids for internal managers calls (#5544)
* claude coded first pass * fix test cases to expect errors instead * fix this * let's see how letta-code did * claude * fix tests, remove dangling comments, retrofit all managers functions with decorator * revert to main for these since we are not erroring on invalid tool and block ids * reorder decorators * finish refactoring test cases * reorder agent_manager decorators and fix test tool manager * add decorator on missing managers * fix id sources * remove redundant check * uses enum now * move to enum
This commit is contained in:
@@ -8,6 +8,7 @@ from letta.constants import CORE_MEMORY_LINE_NUMBER_WARNING, DEFAULT_EMBEDDING_C
|
||||
from letta.errors import AgentExportProcessingError
|
||||
from letta.schemas.block import Block, CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.schemas.environment_variables import AgentEnvironmentVariable
|
||||
from letta.schemas.file import FileStatus
|
||||
from letta.schemas.group import Group
|
||||
@@ -57,7 +58,7 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
|
||||
embedding_config (EmbeddingConfig): The embedding configuration used by the agent.
|
||||
"""
|
||||
|
||||
__id_prefix__ = "agent"
|
||||
__id_prefix__ = PrimitiveType.AGENT.value
|
||||
|
||||
# NOTE: this is what is returned to the client and also what is used to initialize `Agent`
|
||||
id: str = Field(..., description="The id of the agent. Assigned by the database.")
|
||||
|
||||
@@ -3,12 +3,12 @@ from typing import Dict, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.enums import VectorDBProvider
|
||||
from letta.schemas.enums import PrimitiveType, VectorDBProvider
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
|
||||
|
||||
class ArchiveBase(OrmMetadataBase):
|
||||
__id_prefix__ = "archive"
|
||||
__id_prefix__ = PrimitiveType.ARCHIVE.value
|
||||
|
||||
name: str = Field(..., description="The name of the archive")
|
||||
description: Optional[str] = Field(None, description="A description of the archive")
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Any, Optional
|
||||
from pydantic import ConfigDict, Field, model_validator
|
||||
|
||||
from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT, DEFAULT_HUMAN_BLOCK_DESCRIPTION, DEFAULT_PERSONA_BLOCK_DESCRIPTION
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
# block of the LLM context
|
||||
@@ -12,7 +13,7 @@ from letta.schemas.letta_base import LettaBase
|
||||
class BaseBlock(LettaBase, validate_assignment=True):
|
||||
"""Base block of the LLM context"""
|
||||
|
||||
__id_prefix__ = "block"
|
||||
__id_prefix__ = PrimitiveType.BLOCK.value
|
||||
|
||||
# data value
|
||||
value: str = Field(..., description="Value of the block.")
|
||||
|
||||
@@ -1,6 +1,31 @@
|
||||
from enum import Enum, StrEnum
|
||||
|
||||
|
||||
class PrimitiveType(str, Enum):
|
||||
"""
|
||||
Enum for all primitive resource types in Letta.
|
||||
|
||||
The enum values ARE the actual ID prefixes used in the system.
|
||||
This serves as the single source of truth for all ID prefixes.
|
||||
"""
|
||||
|
||||
AGENT = "agent"
|
||||
MESSAGE = "message"
|
||||
RUN = "run"
|
||||
JOB = "job"
|
||||
GROUP = "group"
|
||||
BLOCK = "block"
|
||||
FILE = "file"
|
||||
FOLDER = "source" # Note: folder IDs use "source" prefix for historical reasons
|
||||
SOURCE = "source"
|
||||
TOOL = "tool"
|
||||
ARCHIVE = "archive"
|
||||
PROVIDER = "provider"
|
||||
SANDBOX_CONFIG = "sandbox" # Note: sandbox_config IDs use "sandbox" prefix
|
||||
STEP = "step"
|
||||
IDENTITY = "identity"
|
||||
|
||||
|
||||
class ProviderType(str, Enum):
|
||||
anthropic = "anthropic"
|
||||
azure = "azure"
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import List, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.enums import FileProcessingStatus
|
||||
from letta.schemas.enums import FileProcessingStatus, PrimitiveType
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ class FileStatus(str, Enum):
|
||||
class FileMetadataBase(LettaBase):
|
||||
"""Base class for FileMetadata schemas"""
|
||||
|
||||
__id_prefix__ = "file"
|
||||
__id_prefix__ = PrimitiveType.FILE.value
|
||||
|
||||
# Core file metadata fields
|
||||
source_id: str = Field(..., description="The unique identifier of the source associated with the document.")
|
||||
@@ -61,7 +61,7 @@ class FileMetadata(FileMetadataBase):
|
||||
class FileAgentBase(LettaBase):
|
||||
"""Base class for the FileMetadata-⇄-Agent association schemas"""
|
||||
|
||||
__id_prefix__ = "file_agent"
|
||||
__id_prefix__ = PrimitiveType.FILE.value
|
||||
|
||||
# Core file-agent association fields
|
||||
agent_id: str = Field(..., description="Unique identifier of the agent.")
|
||||
|
||||
@@ -4,6 +4,7 @@ from typing import Optional
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
@@ -12,7 +13,7 @@ class BaseFolder(LettaBase):
|
||||
Shared attributes across all folder schemas.
|
||||
"""
|
||||
|
||||
__id_prefix__ = "source" # TODO: change to "folder"
|
||||
__id_prefix__ = PrimitiveType.FOLDER.value # TODO: change to "folder"
|
||||
|
||||
# Core folder fields
|
||||
name: str = Field(..., description="The name of the folder.")
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Annotated, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
@@ -20,7 +21,7 @@ class ManagerConfig(BaseModel):
|
||||
|
||||
|
||||
class GroupBase(LettaBase):
|
||||
__id_prefix__ = "group"
|
||||
__id_prefix__ = PrimitiveType.GROUP.value
|
||||
|
||||
|
||||
class Group(GroupBase):
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import List, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
@@ -28,7 +29,7 @@ class IdentityPropertyType(str, Enum):
|
||||
|
||||
|
||||
class IdentityBase(LettaBase):
|
||||
__id_prefix__ = "identity"
|
||||
__id_prefix__ = PrimitiveType.IDENTITY.value
|
||||
|
||||
|
||||
class IdentityProperty(LettaBase):
|
||||
|
||||
@@ -3,6 +3,8 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.schemas.letta_request import LettaRequest
|
||||
|
||||
@@ -15,7 +17,7 @@ from letta.schemas.letta_stop_reason import StopReasonType
|
||||
|
||||
|
||||
class JobBase(OrmMetadataBase):
|
||||
__id_prefix__ = "job"
|
||||
__id_prefix__ = PrimitiveType.JOB.value
|
||||
status: JobStatus = Field(default=JobStatus.created, description="The status of the job.")
|
||||
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the job was created.")
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, RE
|
||||
from letta.helpers.datetime_helpers import get_utc_time, is_utc_datetime
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_VERTEX
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.enums import MessageRole, PrimitiveType
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
from letta.schemas.letta_message import (
|
||||
ApprovalRequestMessage,
|
||||
@@ -170,7 +170,7 @@ class MessageUpdate(BaseModel):
|
||||
|
||||
|
||||
class BaseMessage(OrmMetadataBase):
|
||||
__id_prefix__ = "message"
|
||||
__id_prefix__ = PrimitiveType.MESSAGE.value
|
||||
|
||||
|
||||
class Message(BaseMessage):
|
||||
|
||||
@@ -4,7 +4,7 @@ from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES
|
||||
from letta.schemas.enums import ProviderCategory, ProviderType
|
||||
from letta.schemas.enums import PrimitiveType, ProviderCategory, ProviderType
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.llm_config_overrides import LLM_HANDLE_OVERRIDES
|
||||
@@ -13,7 +13,7 @@ from letta.settings import model_settings
|
||||
|
||||
|
||||
class ProviderBase(LettaBase):
|
||||
__id_prefix__ = "provider"
|
||||
__id_prefix__ = PrimitiveType.PROVIDER.value
|
||||
|
||||
|
||||
class Provider(ProviderBase):
|
||||
|
||||
@@ -4,14 +4,14 @@ from typing import Optional
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.schemas.enums import RunStatus
|
||||
from letta.schemas.enums import PrimitiveType, RunStatus
|
||||
from letta.schemas.job import LettaRequestConfig
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.schemas.letta_stop_reason import StopReasonType
|
||||
|
||||
|
||||
class RunBase(LettaBase):
|
||||
__id_prefix__ = "run"
|
||||
__id_prefix__ = PrimitiveType.RUN.value
|
||||
|
||||
|
||||
class Run(RunBase):
|
||||
|
||||
@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from letta.constants import LETTA_TOOL_EXECUTION_DIR
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import SandboxType
|
||||
from letta.schemas.enums import PrimitiveType, SandboxType
|
||||
from letta.schemas.letta_base import LettaBase, OrmMetadataBase
|
||||
from letta.schemas.pip_requirement import PipRequirement
|
||||
from letta.services.tool_sandbox.modal_constants import DEFAULT_MODAL_TIMEOUT
|
||||
@@ -92,7 +92,7 @@ class ModalSandboxConfig(BaseModel):
|
||||
|
||||
|
||||
class SandboxConfigBase(OrmMetadataBase):
|
||||
__id_prefix__ = "sandbox"
|
||||
__id_prefix__ = PrimitiveType.SANDBOX_CONFIG.value
|
||||
|
||||
|
||||
class SandboxConfig(SandboxConfigBase):
|
||||
|
||||
@@ -5,7 +5,7 @@ from pydantic import Field
|
||||
|
||||
from letta.helpers.tpuf_client import should_use_tpuf
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import VectorDBProvider
|
||||
from letta.schemas.enums import PrimitiveType, VectorDBProvider
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ class BaseSource(LettaBase):
|
||||
Shared attributes across all source schemas.
|
||||
"""
|
||||
|
||||
__id_prefix__ = "source"
|
||||
__id_prefix__ = PrimitiveType.SOURCE.value
|
||||
|
||||
# Core source fields
|
||||
name: str = Field(..., description="The name of the source.")
|
||||
|
||||
@@ -3,14 +3,14 @@ from typing import Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.enums import StepStatus
|
||||
from letta.schemas.enums import PrimitiveType, StepStatus
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.schemas.letta_stop_reason import StopReasonType
|
||||
from letta.schemas.message import Message
|
||||
|
||||
|
||||
class StepBase(LettaBase):
|
||||
__id_prefix__ = "step"
|
||||
__id_prefix__ = PrimitiveType.STEP.value
|
||||
|
||||
|
||||
class Step(StepBase):
|
||||
|
||||
@@ -11,6 +11,7 @@ from letta.constants import (
|
||||
LETTA_VOICE_TOOL_MODULE_NAME,
|
||||
MCP_TOOL_TAG_NAME_PREFIX,
|
||||
)
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
|
||||
# MCP Tool metadata constants for schema health status
|
||||
MCP_TOOL_METADATA_SCHEMA_STATUS = f"{MCP_TOOL_TAG_NAME_PREFIX}:SCHEMA_STATUS"
|
||||
@@ -28,7 +29,7 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BaseTool(LettaBase):
|
||||
__id_prefix__ = "tool"
|
||||
__id_prefix__ = PrimitiveType.TOOL.value
|
||||
|
||||
|
||||
class Tool(BaseTool):
|
||||
|
||||
@@ -58,7 +58,7 @@ from letta.schemas.agent import (
|
||||
)
|
||||
from letta.schemas.block import DEFAULT_BLOCKS, Block as PydanticBlock, BlockUpdate
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import AgentType, ProviderType, TagMatchMode, ToolType, VectorDBProvider
|
||||
from letta.schemas.enums import AgentType, PrimitiveType, ProviderType, TagMatchMode, ToolType, VectorDBProvider
|
||||
from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||||
from letta.schemas.group import Group as PydanticGroup, ManagerType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
@@ -110,7 +110,7 @@ from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import DatabaseChoice, model_settings, settings
|
||||
from letta.utils import calculate_file_defaults_based_on_context_window, enforce_types, united_diff
|
||||
from letta.validators import is_valid_id
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -672,6 +672,7 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
async def update_agent_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -976,6 +977,7 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
async def get_agent_by_id_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -984,10 +986,6 @@ class AgentManager:
|
||||
) -> PydanticAgentState:
|
||||
"""Fetch an agent by its ID."""
|
||||
|
||||
# Check if agent_id matches uuid4 format
|
||||
if not is_valid_id("agent", agent_id):
|
||||
raise LettaAgentNotFoundError(f"agent_id {agent_id} is not in the valid format 'agent-<uuid4>'")
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
query = select(AgentModel)
|
||||
@@ -1039,6 +1037,7 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
async def get_agent_archive_ids_async(self, agent_id: str, actor: PydanticUser) -> List[str]:
|
||||
"""Get all archive IDs associated with an agent."""
|
||||
from letta.orm import ArchivesAgents
|
||||
@@ -1052,6 +1051,7 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
async def validate_agent_exists_async(self, agent_id: str, actor: PydanticUser) -> None:
|
||||
"""
|
||||
Validate that an agent exists and user has access to it.
|
||||
@@ -1069,6 +1069,7 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
async def delete_agent_async(self, agent_id: str, actor: PydanticUser) -> None:
|
||||
"""
|
||||
Deletes an agent and its associated relationships.
|
||||
@@ -1131,6 +1132,7 @@ class AgentManager:
|
||||
# TODO: This can also be made more efficient, instead of getting, setting, we can do it all in one db session for one query.
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
async def get_in_context_messages(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]:
|
||||
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||||
return await self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)
|
||||
@@ -1143,6 +1145,7 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
async def get_system_message_async(self, agent_id: str, actor: PydanticUser) -> PydanticMessage:
|
||||
agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor)
|
||||
return await self.message_manager.get_message_by_id_async(message_id=agent.message_ids[0], actor=actor)
|
||||
@@ -1324,6 +1327,7 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
async def set_in_context_messages_async(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
|
||||
return await self.update_agent_async(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor)
|
||||
|
||||
@@ -1434,6 +1438,7 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
async def update_memory_if_changed_async(self, agent_id: str, new_memory: Memory, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Update internal memory object and system prompt if there have been modifications.
|
||||
@@ -1546,6 +1551,8 @@ class AgentManager:
|
||||
# ======================================================================================================================
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
|
||||
async def attach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Attaches a source to an agent.
|
||||
@@ -1615,6 +1622,7 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
async def append_system_message_async(self, agent_id: str, content: str, actor: PydanticUser):
|
||||
"""
|
||||
Async version of append_system_message.
|
||||
@@ -1702,6 +1710,8 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
|
||||
async def detach_source_async(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Detaches a source from an agent.
|
||||
@@ -1787,6 +1797,8 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
|
||||
async def attach_block_async(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""Attaches a block to an agent. For sleeptime agents, also attaches to paired agents in the same group."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -2373,8 +2385,9 @@ class AgentManager:
|
||||
# Tool Management
|
||||
# ======================================================================================================================
|
||||
@enforce_types
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@raise_on_invalid_id(param_name="tool_id", expected_prefix=PrimitiveType.TOOL)
|
||||
async def attach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> None:
|
||||
"""
|
||||
Attaches a tool to an agent.
|
||||
@@ -2443,6 +2456,7 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
async def bulk_attach_tools_async(self, agent_id: str, tool_ids: List[str], actor: PydanticUser) -> None:
|
||||
"""
|
||||
Efficiently attaches multiple tools to an agent in a single operation.
|
||||
@@ -2608,6 +2622,8 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@raise_on_invalid_id(param_name="tool_id", expected_prefix=PrimitiveType.TOOL)
|
||||
async def detach_tool_async(self, agent_id: str, tool_id: str, actor: PydanticUser) -> None:
|
||||
"""
|
||||
Detaches a tool from an agent.
|
||||
@@ -2637,6 +2653,7 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
async def bulk_detach_tools_async(self, agent_id: str, tool_ids: List[str], actor: PydanticUser) -> None:
|
||||
"""
|
||||
Efficiently detaches multiple tools from an agent in a single operation.
|
||||
@@ -2673,6 +2690,7 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
async def modify_approvals_async(self, agent_id: str, tool_name: str, requires_approval: bool, actor: PydanticUser) -> None:
|
||||
def is_target_rule(rule):
|
||||
return rule.tool_name == tool_name and rule.type == "requires_approval"
|
||||
@@ -3021,6 +3039,7 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
async def get_agent_files_config_async(self, agent_id: str, actor: PydanticUser) -> Tuple[int, int]:
|
||||
"""Get per_file_view_window_char_limit and max_files_open for an agent.
|
||||
|
||||
@@ -3077,6 +3096,7 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
async def get_agent_max_files_open_async(self, agent_id: str, actor: PydanticUser) -> int:
|
||||
"""Get max_files_open for an agent.
|
||||
|
||||
@@ -3105,6 +3125,7 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
async def get_agent_per_file_view_window_char_limit_async(self, agent_id: str, actor: PydanticUser) -> int:
|
||||
"""Get per_file_view_window_char_limit for an agent.
|
||||
|
||||
@@ -3131,7 +3152,9 @@ class AgentManager:
|
||||
|
||||
return row
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
async def get_context_window(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview:
|
||||
agent_state, system_message, num_messages, num_archival_memories = await self.rebuild_system_prompt_async(
|
||||
agent_id=agent_id, actor=actor, force=True, dry_run=True
|
||||
|
||||
@@ -7,11 +7,12 @@ from letta.log import get_logger
|
||||
from letta.orm import ArchivalPassage, Archive as ArchiveModel, ArchivesAgents
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.archive import Archive as PydanticArchive
|
||||
from letta.schemas.enums import VectorDBProvider
|
||||
from letta.schemas.enums import PrimitiveType, VectorDBProvider
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.settings import settings
|
||||
from letta.utils import enforce_types
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -47,6 +48,7 @@ class ArchiveManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
|
||||
async def get_archive_by_id_async(
|
||||
self,
|
||||
archive_id: str,
|
||||
@@ -63,6 +65,7 @@ class ArchiveManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
|
||||
async def update_archive_async(
|
||||
self,
|
||||
archive_id: str,
|
||||
@@ -89,6 +92,7 @@ class ArchiveManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
async def list_archives_async(
|
||||
self,
|
||||
*,
|
||||
@@ -136,6 +140,8 @@ class ArchiveManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
|
||||
async def attach_agent_to_archive_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -172,6 +178,7 @@ class ArchiveManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
async def get_default_archive_for_agent_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -204,6 +211,7 @@ class ArchiveManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
|
||||
async def delete_archive_async(
|
||||
self,
|
||||
archive_id: str,
|
||||
@@ -221,6 +229,7 @@ class ArchiveManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
async def get_or_create_default_archive_for_agent_async(
|
||||
self,
|
||||
agent_id: str,
|
||||
@@ -291,6 +300,7 @@ class ArchiveManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
|
||||
async def get_agents_for_archive_async(
|
||||
self,
|
||||
archive_id: str,
|
||||
@@ -333,6 +343,7 @@ class ArchiveManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="archive_id", expected_prefix=PrimitiveType.ARCHIVE)
|
||||
async def get_or_set_vector_db_namespace_async(
|
||||
self,
|
||||
archive_id: str,
|
||||
|
||||
@@ -15,11 +15,12 @@ from letta.orm.errors import NoResultFound
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.block import Block as PydanticBlock, BlockUpdate
|
||||
from letta.schemas.enums import ActorType
|
||||
from letta.schemas.enums import ActorType, PrimitiveType
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
from letta.utils import enforce_types
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -134,10 +135,9 @@ class BlockManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
|
||||
async def update_block_async(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock:
|
||||
"""Update a block by its ID with the given BlockUpdate object."""
|
||||
# Safety check for block
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)
|
||||
update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
@@ -155,6 +155,7 @@ class BlockManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
|
||||
async def delete_block_async(self, block_id: str, actor: PydanticUser) -> None:
|
||||
"""Delete a block by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -353,6 +354,7 @@ class BlockManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
|
||||
async def get_block_by_id_async(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]:
|
||||
"""Retrieve a block by its name."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -412,6 +414,7 @@ class BlockManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
|
||||
async def get_agents_for_block_async(
|
||||
self,
|
||||
block_id: str,
|
||||
@@ -595,6 +598,8 @@ class BlockManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
async def checkpoint_block_async(
|
||||
self,
|
||||
block_id: str,
|
||||
@@ -703,6 +708,7 @@ class BlockManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
|
||||
async def undo_checkpoint_block(
|
||||
self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None
|
||||
) -> PydanticBlock:
|
||||
@@ -753,6 +759,7 @@ class BlockManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
|
||||
async def redo_checkpoint_block(
|
||||
self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None
|
||||
) -> PydanticBlock:
|
||||
|
||||
@@ -15,7 +15,7 @@ from letta.orm.errors import NoResultFound
|
||||
from letta.orm.file import FileContent as FileContentModel, FileMetadata as FileMetadataModel
|
||||
from letta.orm.sqlalchemy_base import AccessType
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.enums import FileProcessingStatus
|
||||
from letta.schemas.enums import FileProcessingStatus, PrimitiveType
|
||||
from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
from letta.schemas.source_metadata import FileStats, OrganizationSourcesStats, SourceStats
|
||||
@@ -23,6 +23,7 @@ from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.settings import settings
|
||||
from letta.utils import enforce_types
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -93,6 +94,7 @@ class FileManager:
|
||||
# TODO: We make actor optional for now, but should most likely be enforced due to security reasons
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE)
|
||||
# @async_redis_cache(
|
||||
# key_func=lambda self, file_id, actor=None, include_content=False, strip_directory_prefix=False: f"{file_id}:{actor.organization_id if actor else 'none'}:{include_content}:{strip_directory_prefix}",
|
||||
# prefix="file_content",
|
||||
@@ -135,6 +137,7 @@ class FileManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE)
|
||||
async def update_file_status(
|
||||
self,
|
||||
*,
|
||||
@@ -171,7 +174,6 @@ class FileManager:
|
||||
* 1st round-trip → UPDATE with optional state validation
|
||||
* 2nd round-trip → SELECT fresh row (same as read_async) if update succeeded
|
||||
"""
|
||||
|
||||
if processing_status is None and error_message is None and total_chunks is None and chunks_embedded is None:
|
||||
raise ValueError("Nothing to update")
|
||||
|
||||
@@ -353,6 +355,7 @@ class FileManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE)
|
||||
async def upsert_file_content(
|
||||
self,
|
||||
*,
|
||||
@@ -398,6 +401,7 @@ class FileManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
|
||||
async def list_files(
|
||||
self,
|
||||
source_id: str,
|
||||
@@ -455,6 +459,7 @@ class FileManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="file_id", expected_prefix=PrimitiveType.FILE)
|
||||
async def delete_file(self, file_id: str, actor: PydanticUser) -> PydanticFileMetadata:
|
||||
"""Delete a file by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -509,6 +514,7 @@ class FileManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
|
||||
# @async_redis_cache(
|
||||
# key_func=lambda self, original_filename, source_id, actor: f"{original_filename}:{source_id}:{actor.organization_id}",
|
||||
# prefix="file_by_name",
|
||||
|
||||
@@ -9,6 +9,7 @@ from letta.orm.errors import NoResultFound
|
||||
from letta.orm.group import Group as GroupModel
|
||||
from letta.orm.message import Message as MessageModel
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.schemas.group import Group as PydanticGroup, GroupCreate, GroupUpdate, InternalTemplateGroupCreate, ManagerType
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
@@ -16,6 +17,7 @@ from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
from letta.utils import enforce_types
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
|
||||
class GroupManager:
|
||||
@@ -62,6 +64,7 @@ class GroupManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
|
||||
async def retrieve_group_async(self, group_id: str, actor: PydanticUser) -> PydanticGroup:
|
||||
async with db_registry.async_session() as session:
|
||||
group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor)
|
||||
@@ -119,6 +122,7 @@ class GroupManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
|
||||
async def modify_group_async(self, group_id: str, group_update: GroupUpdate, actor: PydanticUser) -> PydanticGroup:
|
||||
async with db_registry.async_session() as session:
|
||||
group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor)
|
||||
@@ -182,6 +186,7 @@ class GroupManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
|
||||
async def delete_group_async(self, group_id: str, actor: PydanticUser) -> None:
|
||||
async with db_registry.async_session() as session:
|
||||
group = await GroupModel.read_async(db_session=session, identifier=group_id, actor=actor)
|
||||
@@ -189,6 +194,7 @@ class GroupManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
|
||||
async def list_group_messages_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
@@ -226,6 +232,7 @@ class GroupManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
|
||||
async def reset_messages_async(self, group_id: str, actor: PydanticUser) -> None:
|
||||
async with db_registry.async_session() as session:
|
||||
# Ensure group is loadable by user
|
||||
@@ -241,6 +248,7 @@ class GroupManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
|
||||
async def bump_turns_counter_async(self, group_id: str, actor: PydanticUser) -> int:
|
||||
async with db_registry.async_session() as session:
|
||||
# Ensure group is loadable by user
|
||||
@@ -253,6 +261,8 @@ class GroupManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="group_id", expected_prefix=PrimitiveType.GROUP)
|
||||
@raise_on_invalid_id(param_name="last_processed_message_id", expected_prefix=PrimitiveType.MESSAGE)
|
||||
async def get_last_processed_message_id_and_update_async(
|
||||
self, group_id: str, last_processed_message_id: str, actor: PydanticUser
|
||||
) -> str:
|
||||
|
||||
@@ -12,6 +12,7 @@ from letta.orm.identity import Identity as IdentityModel
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.schemas.identity import (
|
||||
Identity as PydanticIdentity,
|
||||
IdentityCreate,
|
||||
@@ -24,6 +25,7 @@ from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
from letta.utils import enforce_types
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
|
||||
class IdentityManager:
|
||||
@@ -62,6 +64,7 @@ class IdentityManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
|
||||
async def get_identity_async(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity:
|
||||
async with db_registry.async_session() as session:
|
||||
identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
|
||||
@@ -143,6 +146,7 @@ class IdentityManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
|
||||
async def update_identity_async(
|
||||
self, identity_id: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False
|
||||
) -> PydanticIdentity:
|
||||
@@ -206,6 +210,7 @@ class IdentityManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
|
||||
async def upsert_identity_properties_async(
|
||||
self, identity_id: str, properties: List[IdentityProperty], actor: PydanticUser
|
||||
) -> PydanticIdentity:
|
||||
@@ -223,6 +228,7 @@ class IdentityManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
|
||||
async def delete_identity_async(self, identity_id: str, actor: PydanticUser) -> None:
|
||||
async with db_registry.async_session() as session:
|
||||
identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
|
||||
@@ -280,6 +286,7 @@ class IdentityManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
|
||||
async def list_agents_for_identity_async(
|
||||
self,
|
||||
identity_id: str,
|
||||
@@ -311,6 +318,7 @@ class IdentityManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
|
||||
async def list_blocks_for_identity_async(
|
||||
self,
|
||||
identity_id: str,
|
||||
|
||||
@@ -14,7 +14,7 @@ from letta.orm.message import Message as MessageModel
|
||||
from letta.orm.sqlalchemy_base import AccessType
|
||||
from letta.orm.step import Step, Step as StepModel
|
||||
from letta.otel.tracing import log_event, trace_method
|
||||
from letta.schemas.enums import JobStatus, JobType, MessageRole
|
||||
from letta.schemas.enums import JobStatus, JobType, MessageRole, PrimitiveType
|
||||
from letta.schemas.job import BatchJob as PydanticBatchJob, Job as PydanticJob, JobUpdate, LettaRequestConfig
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.letta_stop_reason import StopReasonType
|
||||
@@ -26,6 +26,7 @@ from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.services.helpers.agent_manager_helper import validate_agent_exists_async
|
||||
from letta.utils import enforce_types
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -70,6 +71,7 @@ class JobManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
|
||||
async def update_job_by_id_async(
|
||||
self, job_id: str, job_update: JobUpdate, actor: PydanticUser, safe_update: bool = False
|
||||
) -> PydanticJob:
|
||||
@@ -147,6 +149,7 @@ class JobManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
|
||||
async def safe_update_job_status_async(
|
||||
self,
|
||||
job_id: str,
|
||||
@@ -187,6 +190,7 @@ class JobManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
|
||||
async def get_job_by_id_async(self, job_id: str, actor: PydanticUser) -> PydanticJob:
|
||||
"""Fetch a job by its ID asynchronously."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -301,6 +305,7 @@ class JobManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
|
||||
async def delete_job_by_id_async(self, job_id: str, actor: PydanticUser) -> PydanticJob:
|
||||
"""Delete a job by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -310,6 +315,7 @@ class JobManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
||||
async def get_run_messages(
|
||||
self,
|
||||
run_id: str,
|
||||
@@ -367,6 +373,7 @@ class JobManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
||||
async def get_step_messages(
|
||||
self,
|
||||
run_id: str,
|
||||
@@ -447,6 +454,7 @@ class JobManager:
|
||||
return job
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
|
||||
async def record_ttft(self, job_id: str, ttft_ns: int, actor: PydanticUser) -> None:
|
||||
"""Record time to first token for a run"""
|
||||
try:
|
||||
@@ -459,6 +467,7 @@ class JobManager:
|
||||
logger.warning(f"Failed to record TTFT for job {job_id}: {e}")
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
|
||||
async def record_response_duration(self, job_id: str, total_duration_ns: int, actor: PydanticUser) -> None:
|
||||
"""Record total response duration for a run"""
|
||||
try:
|
||||
@@ -529,6 +538,7 @@ class JobManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
|
||||
async def get_job_steps(
|
||||
self,
|
||||
job_id: str,
|
||||
|
||||
@@ -25,6 +25,7 @@ from letta.orm.errors import NoResultFound
|
||||
from letta.orm.mcp_oauth import MCPOAuth, OAuthSessionStatus
|
||||
from letta.orm.mcp_server import MCPServer as MCPServerModel
|
||||
from letta.orm.tool import Tool as ToolModel
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.schemas.mcp import (
|
||||
MCPOAuthSession,
|
||||
MCPOAuthSessionCreate,
|
||||
@@ -47,6 +48,7 @@ from letta.services.mcp.streamable_http_client import AsyncStreamableHTTPMCPClie
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import settings, tool_settings
|
||||
from letta.utils import enforce_types, printd, safe_create_task_with_return
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -60,6 +62,7 @@ class MCPManager:
|
||||
self.cached_mcp_servers = {} # maps id -> async connection
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
async def list_mcp_server_tools(self, mcp_server_name: str, actor: PydanticUser, agent_id: Optional[str] = None) -> List[MCPTool]:
|
||||
"""Get a list of all tools for a specific MCP server."""
|
||||
mcp_client = None
|
||||
|
||||
@@ -10,7 +10,7 @@ from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.message import Message as MessageModel
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.enums import MessageRole, PrimitiveType
|
||||
from letta.schemas.letta_message import LettaMessageUpdateUnion
|
||||
from letta.schemas.letta_message_content import ImageSourceType, LettaImage, MessageContentType, TextContent
|
||||
from letta.schemas.message import Message as PydanticMessage, MessageSearchResult, MessageUpdate
|
||||
@@ -20,6 +20,7 @@ from letta.services.file_manager import FileManager
|
||||
from letta.services.helpers.agent_manager_helper import validate_agent_exists_async
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
from letta.utils import enforce_types, fire_and_forget
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -308,6 +309,7 @@ class MessageManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="message_id", expected_prefix=PrimitiveType.MESSAGE)
|
||||
async def get_message_by_id_async(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]:
|
||||
"""Fetch a message by ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -712,6 +714,7 @@ class MessageManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="message_id", expected_prefix=PrimitiveType.MESSAGE)
|
||||
async def delete_message_by_id_async(self, message_id: str, actor: PydanticUser, strict_mode: bool = False) -> bool:
|
||||
"""Delete a message (async version with turbopuffer support)."""
|
||||
# capture agent_id before deletion
|
||||
|
||||
@@ -2,12 +2,13 @@ from typing import List, Optional, Tuple, Union
|
||||
|
||||
from letta.orm.provider import Provider as ProviderModel
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.enums import ProviderCategory, ProviderType
|
||||
from letta.schemas.enums import PrimitiveType, ProviderCategory, ProviderType
|
||||
from letta.schemas.providers import Provider as PydanticProvider, ProviderCheck, ProviderCreate, ProviderUpdate
|
||||
from letta.schemas.secret import Secret
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.utils import enforce_types
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
|
||||
class ProviderManager:
|
||||
@@ -40,6 +41,7 @@ class ProviderManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
|
||||
async def update_provider_async(self, provider_id: str, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider:
|
||||
"""Update provider details."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -103,6 +105,7 @@ class ProviderManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
|
||||
async def delete_provider_by_id_async(self, provider_id: str, actor: PydanticUser):
|
||||
"""Delete a provider."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -153,6 +156,7 @@ class ProviderManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
|
||||
async def get_provider_async(self, provider_id: str, actor: PydanticUser) -> PydanticProvider:
|
||||
async with db_registry.async_session() as session:
|
||||
provider_model = await ProviderModel.read_async(db_session=session, identifier=provider_id, actor=actor)
|
||||
|
||||
@@ -14,7 +14,7 @@ from letta.orm.run_metrics import RunMetrics as RunMetricsModel
|
||||
from letta.orm.sqlalchemy_base import AccessType
|
||||
from letta.orm.step import Step as StepModel
|
||||
from letta.otel.tracing import log_event, trace_method
|
||||
from letta.schemas.enums import AgentType, ComparisonOperator, MessageRole, RunStatus
|
||||
from letta.schemas.enums import AgentType, ComparisonOperator, MessageRole, RunStatus, PrimitiveType
|
||||
from letta.schemas.job import LettaRequestConfig
|
||||
from letta.schemas.letta_message import LettaMessage, LettaMessageUnion
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
@@ -31,6 +31,7 @@ from letta.services.helpers.agent_manager_helper import validate_agent_exists_as
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.step_manager import StepManager
|
||||
from letta.utils import enforce_types
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -85,6 +86,7 @@ class RunManager:
|
||||
return run.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
||||
async def get_run_by_id(self, run_id: str, actor: PydanticUser) -> PydanticRun:
|
||||
"""Get a run by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -176,6 +178,7 @@ class RunManager:
|
||||
return [run.to_pydantic() for run in runs]
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
||||
async def delete_run(self, run_id: str, actor: PydanticUser) -> PydanticRun:
|
||||
"""Delete a run by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -189,11 +192,11 @@ class RunManager:
|
||||
return pydantic_run
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
||||
async def update_run_by_id_async(
|
||||
self, run_id: str, update: RunUpdate, actor: PydanticUser, refresh_result_messages: bool = True
|
||||
) -> PydanticRun:
|
||||
"""Update a run using a RunUpdate object."""
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor)
|
||||
|
||||
@@ -327,6 +330,7 @@ class RunManager:
|
||||
return result
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
||||
async def get_run_usage(self, run_id: str, actor: PydanticUser) -> LettaUsageStatistics:
|
||||
"""Get usage statistics for a run."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -344,6 +348,7 @@ class RunManager:
|
||||
return total_usage
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
||||
async def get_run_messages(
|
||||
self,
|
||||
run_id: str,
|
||||
@@ -378,6 +383,7 @@ class RunManager:
|
||||
return letta_messages
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
||||
async def get_run_request_config(self, run_id: str, actor: PydanticUser) -> Optional[LettaRequestConfig]:
|
||||
"""Get the letta request config from a run."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -388,6 +394,7 @@ class RunManager:
|
||||
return pydantic_run.request_config
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
||||
async def get_run_metrics_async(self, run_id: str, actor: PydanticUser) -> PydanticRunMetrics:
|
||||
"""Get metrics for a run."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -395,6 +402,7 @@ class RunManager:
|
||||
return metrics.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
||||
async def get_run_steps(
|
||||
self,
|
||||
run_id: str,
|
||||
|
||||
@@ -5,7 +5,7 @@ from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.sandbox_config import SandboxConfig as SandboxConfigModel, SandboxEnvironmentVariable as SandboxEnvVarModel
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.enums import SandboxType
|
||||
from letta.schemas.enums import PrimitiveType, SandboxType
|
||||
from letta.schemas.environment_variables import (
|
||||
SandboxEnvironmentVariable as PydanticEnvVar,
|
||||
SandboxEnvironmentVariableCreate,
|
||||
@@ -20,6 +20,7 @@ from letta.schemas.sandbox_config import (
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.utils import enforce_types, printd
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -101,6 +102,7 @@ class SandboxConfigManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
|
||||
async def update_sandbox_config_async(
|
||||
self, sandbox_config_id: str, sandbox_update: SandboxConfigUpdate, actor: PydanticUser
|
||||
) -> PydanticSandboxConfig:
|
||||
@@ -129,6 +131,7 @@ class SandboxConfigManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
|
||||
async def delete_sandbox_config_async(self, sandbox_config_id: str, actor: PydanticUser) -> PydanticSandboxConfig:
|
||||
"""Delete a sandbox configuration by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -176,6 +179,7 @@ class SandboxConfigManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
|
||||
async def create_sandbox_env_var_async(
|
||||
self, env_var_create: SandboxEnvironmentVariableCreate, sandbox_config_id: str, actor: PydanticUser
|
||||
) -> PydanticEnvVar:
|
||||
@@ -266,6 +270,7 @@ class SandboxConfigManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
|
||||
async def list_sandbox_env_vars_async(
|
||||
self,
|
||||
sandbox_config_id: str,
|
||||
@@ -302,6 +307,7 @@ class SandboxConfigManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
|
||||
def get_sandbox_env_vars_as_dict(
|
||||
self, sandbox_config_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50
|
||||
) -> Dict[str, str]:
|
||||
@@ -315,6 +321,7 @@ class SandboxConfigManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
|
||||
async def get_sandbox_env_vars_as_dict_async(
|
||||
self, sandbox_config_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50
|
||||
) -> Dict[str, str]:
|
||||
@@ -324,6 +331,7 @@ class SandboxConfigManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="sandbox_config_id", expected_prefix=PrimitiveType.SANDBOX_CONFIG)
|
||||
async def get_sandbox_env_var_by_key_and_sandbox_config_id_async(
|
||||
self, key: str, sandbox_config_id: str, actor: Optional[PydanticUser] = None
|
||||
) -> Optional[PydanticEnvVar]:
|
||||
|
||||
@@ -11,11 +11,12 @@ from letta.orm.source import Source as SourceModel
|
||||
from letta.orm.sources_agents import SourcesAgents
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.enums import VectorDBProvider
|
||||
from letta.schemas.enums import PrimitiveType, VectorDBProvider
|
||||
from letta.schemas.source import Source as PydanticSource, SourceUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.utils import enforce_types, printd
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
|
||||
class SourceManager:
|
||||
@@ -201,6 +202,7 @@ class SourceManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
|
||||
async def update_source(self, source_id: str, source_update: SourceUpdate, actor: PydanticUser) -> PydanticSource:
|
||||
"""Update a source by its ID with the given SourceUpdate object."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -224,6 +226,7 @@ class SourceManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
|
||||
async def delete_source(self, source_id: str, actor: PydanticUser) -> PydanticSource:
|
||||
"""Delete a source by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -268,6 +271,7 @@ class SourceManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
|
||||
async def list_attached_agents(
|
||||
self, source_id: str, actor: PydanticUser, ids_only: bool = False
|
||||
) -> Union[List[PydanticAgentState], List[str]]:
|
||||
@@ -321,6 +325,7 @@ class SourceManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
|
||||
async def get_agents_for_source_id(self, source_id: str, actor: PydanticUser) -> List[str]:
|
||||
"""
|
||||
Get all agent IDs associated with a given source ID.
|
||||
@@ -347,6 +352,7 @@ class SourceManager:
|
||||
# TODO: We make actor optional for now, but should most likely be enforced due to security reasons
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="source_id", expected_prefix=PrimitiveType.SOURCE)
|
||||
async def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]:
|
||||
"""Retrieve a source by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
|
||||
@@ -13,7 +13,7 @@ from letta.orm.sqlalchemy_base import AccessType
|
||||
from letta.orm.step import Step as StepModel
|
||||
from letta.orm.step_metrics import StepMetrics as StepMetricsModel
|
||||
from letta.otel.tracing import get_trace_id, trace_method
|
||||
from letta.schemas.enums import StepStatus
|
||||
from letta.schemas.enums import PrimitiveType, StepStatus
|
||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
@@ -22,6 +22,7 @@ from letta.schemas.step_metrics import StepMetrics as PydanticStepMetrics
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.utils import enforce_types
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
|
||||
class FeedbackType(str, Enum):
|
||||
@@ -32,6 +33,8 @@ class FeedbackType(str, Enum):
|
||||
class StepManager:
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
||||
async def list_steps_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
@@ -79,6 +82,10 @@ class StepManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
|
||||
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
def log_step(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
@@ -133,6 +140,10 @@ class StepManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
|
||||
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
async def log_step_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
@@ -196,6 +207,7 @@ class StepManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
async def get_step_async(self, step_id: str, actor: PydanticUser) -> PydanticStep:
|
||||
async with db_registry.async_session() as session:
|
||||
step = await StepModel.read_async(db_session=session, identifier=step_id, actor=actor)
|
||||
@@ -203,6 +215,7 @@ class StepManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
async def get_step_metrics_async(self, step_id: str, actor: PydanticUser) -> PydanticStepMetrics:
|
||||
async with db_registry.async_session() as session:
|
||||
metrics = await StepMetricsModel.read_async(db_session=session, identifier=step_id, actor=actor)
|
||||
@@ -210,6 +223,7 @@ class StepManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
async def add_feedback_async(
|
||||
self, step_id: str, feedback: FeedbackType | None, actor: PydanticUser, tags: list[str] | None = None
|
||||
) -> PydanticStep:
|
||||
@@ -225,6 +239,7 @@ class StepManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
async def update_step_transaction_id(self, actor: PydanticUser, step_id: str, transaction_id: str) -> PydanticStep:
|
||||
"""Update the transaction ID for a step.
|
||||
|
||||
@@ -252,6 +267,7 @@ class StepManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
async def list_step_messages_async(
|
||||
self,
|
||||
step_id: str,
|
||||
@@ -276,6 +292,7 @@ class StepManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
async def update_step_stop_reason(self, actor: PydanticUser, step_id: str, stop_reason: StopReasonType) -> PydanticStep:
|
||||
"""Update the stop reason for a step.
|
||||
|
||||
@@ -303,6 +320,7 @@ class StepManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
async def update_step_error_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
@@ -348,6 +366,7 @@ class StepManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
async def update_step_success_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
@@ -388,6 +407,7 @@ class StepManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
async def update_step_cancelled_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
@@ -423,6 +443,9 @@ class StepManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="step_id", expected_prefix=PrimitiveType.STEP)
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
||||
async def record_step_metrics_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
|
||||
@@ -27,7 +27,7 @@ from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.tool import Tool as ToolModel
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.enums import ToolType
|
||||
from letta.schemas.enums import PrimitiveType, ToolType
|
||||
from letta.schemas.tool import Tool as PydanticTool, ToolCreate, ToolUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
@@ -36,6 +36,7 @@ from letta.services.mcp.types import SSEServerConfig, StdioServerConfig
|
||||
from letta.services.tool_schema_generator import generate_schema_for_tool_creation, generate_schema_for_tool_update
|
||||
from letta.settings import settings
|
||||
from letta.utils import enforce_types, printd
|
||||
from letta.validators import raise_on_invalid_id
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -203,6 +204,7 @@ class ToolManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="tool_id", expected_prefix=PrimitiveType.TOOL)
|
||||
async def get_tool_by_id_async(self, tool_id: str, actor: PydanticUser) -> PydanticTool:
|
||||
"""Fetch a tool by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -235,6 +237,7 @@ class ToolManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="tool_id", expected_prefix=PrimitiveType.TOOL)
|
||||
async def tool_exists_async(self, tool_id: str, actor: PydanticUser) -> bool:
|
||||
"""Check if a tool exists and belongs to the user's organization (lightweight check)."""
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -507,6 +510,7 @@ class ToolManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@raise_on_invalid_id(param_name="tool_id", expected_prefix=PrimitiveType.TOOL)
|
||||
async def update_tool_by_id_async(
|
||||
self,
|
||||
tool_id: str,
|
||||
@@ -604,6 +608,7 @@ class ToolManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
# @raise_on_invalid_id This is commented out bc it's called by _list_tools_async, when it encounters malformed tools (i.e. if id is invalid will fail validation on deletion)
|
||||
async def delete_tool_by_id_async(self, tool_id: str, actor: PydanticUser) -> None:
|
||||
"""Delete a tool by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
|
||||
@@ -1,48 +1,21 @@
|
||||
import inspect
|
||||
import re
|
||||
from functools import wraps
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Path
|
||||
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.archive import ArchiveBase
|
||||
from letta.schemas.block import BaseBlock
|
||||
from letta.schemas.file import FileMetadataBase
|
||||
from letta.schemas.folder import BaseFolder
|
||||
from letta.schemas.group import GroupBase
|
||||
from letta.schemas.identity import IdentityBase
|
||||
from letta.schemas.job import JobBase
|
||||
from letta.schemas.message import BaseMessage
|
||||
from letta.schemas.providers import ProviderBase
|
||||
from letta.schemas.run import RunBase
|
||||
from letta.schemas.sandbox_config import SandboxConfigBase
|
||||
from letta.schemas.source import BaseSource
|
||||
from letta.schemas.step import StepBase
|
||||
from letta.schemas.tool import BaseTool
|
||||
from letta.errors import LettaInvalidArgumentError
|
||||
from letta.schemas.enums import PrimitiveType # PrimitiveType is now in schemas.enums
|
||||
|
||||
# TODO: extract this list from routers/v1/__init__.py and ROUTERS
|
||||
primitives = [
|
||||
AgentState.__id_prefix__,
|
||||
BaseMessage.__id_prefix__,
|
||||
RunBase.__id_prefix__,
|
||||
JobBase.__id_prefix__,
|
||||
GroupBase.__id_prefix__,
|
||||
BaseBlock.__id_prefix__,
|
||||
FileMetadataBase.__id_prefix__,
|
||||
BaseFolder.__id_prefix__,
|
||||
BaseSource.__id_prefix__,
|
||||
BaseTool.__id_prefix__,
|
||||
ArchiveBase.__id_prefix__,
|
||||
ProviderBase.__id_prefix__,
|
||||
SandboxConfigBase.__id_prefix__,
|
||||
StepBase.__id_prefix__,
|
||||
IdentityBase.__id_prefix__,
|
||||
]
|
||||
# Map from PrimitiveType to the actual prefix string (which is just the enum value)
|
||||
PRIMITIVE_ID_PREFIXES = {primitive_type: primitive_type.value for primitive_type in PrimitiveType}
|
||||
|
||||
|
||||
PRIMITIVE_ID_PATTERNS = {
|
||||
# f-string interpolation gets confused because of the regex's required curly braces {}
|
||||
primitive: re.compile("^" + primitive + "-[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$")
|
||||
for primitive in primitives
|
||||
prefix: re.compile("^" + prefix + "-[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$")
|
||||
for prefix in PRIMITIVE_ID_PREFIXES.values()
|
||||
}
|
||||
|
||||
|
||||
@@ -67,28 +40,65 @@ def _create_path_validator_factory(primitive: str):
|
||||
|
||||
|
||||
# PATH_VALIDATORS now contains factory functions, not Path objects
|
||||
# Usage: folder_id: str = PATH_VALIDATORS[BaseFolder.__id_prefix__]()
|
||||
PATH_VALIDATORS = {primitive: _create_path_validator_factory(primitive) for primitive in primitives}
|
||||
|
||||
|
||||
def is_valid_id(primitive: str, id: str) -> bool:
|
||||
return PRIMITIVE_ID_PATTERNS[primitive].match(id) is not None
|
||||
# Usage: folder_id: str = PATH_VALIDATORS[PrimitiveType.FOLDER.value]()
|
||||
PATH_VALIDATORS = {primitive_type.value: _create_path_validator_factory(primitive_type.value) for primitive_type in PrimitiveType}
|
||||
|
||||
|
||||
# Type aliases for common ID types
|
||||
# These can be used directly in route handler signatures for cleaner code
|
||||
AgentId = Annotated[str, PATH_VALIDATORS[AgentState.__id_prefix__]()]
|
||||
ToolId = Annotated[str, PATH_VALIDATORS[BaseTool.__id_prefix__]()]
|
||||
SourceId = Annotated[str, PATH_VALIDATORS[BaseSource.__id_prefix__]()]
|
||||
BlockId = Annotated[str, PATH_VALIDATORS[BaseBlock.__id_prefix__]()]
|
||||
MessageId = Annotated[str, PATH_VALIDATORS[BaseMessage.__id_prefix__]()]
|
||||
RunId = Annotated[str, PATH_VALIDATORS[RunBase.__id_prefix__]()]
|
||||
JobId = Annotated[str, PATH_VALIDATORS[JobBase.__id_prefix__]()]
|
||||
GroupId = Annotated[str, PATH_VALIDATORS[GroupBase.__id_prefix__]()]
|
||||
FileId = Annotated[str, PATH_VALIDATORS[FileMetadataBase.__id_prefix__]()]
|
||||
FolderId = Annotated[str, PATH_VALIDATORS[BaseFolder.__id_prefix__]()]
|
||||
ArchiveId = Annotated[str, PATH_VALIDATORS[ArchiveBase.__id_prefix__]()]
|
||||
ProviderId = Annotated[str, PATH_VALIDATORS[ProviderBase.__id_prefix__]()]
|
||||
SandboxConfigId = Annotated[str, PATH_VALIDATORS[SandboxConfigBase.__id_prefix__]()]
|
||||
StepId = Annotated[str, PATH_VALIDATORS[StepBase.__id_prefix__]()]
|
||||
IdentityId = Annotated[str, PATH_VALIDATORS[IdentityBase.__id_prefix__]()]
|
||||
AgentId = Annotated[str, PATH_VALIDATORS[PrimitiveType.AGENT.value]()]
|
||||
ToolId = Annotated[str, PATH_VALIDATORS[PrimitiveType.TOOL.value]()]
|
||||
SourceId = Annotated[str, PATH_VALIDATORS[PrimitiveType.SOURCE.value]()]
|
||||
BlockId = Annotated[str, PATH_VALIDATORS[PrimitiveType.BLOCK.value]()]
|
||||
MessageId = Annotated[str, PATH_VALIDATORS[PrimitiveType.MESSAGE.value]()]
|
||||
RunId = Annotated[str, PATH_VALIDATORS[PrimitiveType.RUN.value]()]
|
||||
JobId = Annotated[str, PATH_VALIDATORS[PrimitiveType.JOB.value]()]
|
||||
GroupId = Annotated[str, PATH_VALIDATORS[PrimitiveType.GROUP.value]()]
|
||||
FileId = Annotated[str, PATH_VALIDATORS[PrimitiveType.FILE.value]()]
|
||||
FolderId = Annotated[str, PATH_VALIDATORS[PrimitiveType.FOLDER.value]()]
|
||||
ArchiveId = Annotated[str, PATH_VALIDATORS[PrimitiveType.ARCHIVE.value]()]
|
||||
ProviderId = Annotated[str, PATH_VALIDATORS[PrimitiveType.PROVIDER.value]()]
|
||||
SandboxConfigId = Annotated[str, PATH_VALIDATORS[PrimitiveType.SANDBOX_CONFIG.value]()]
|
||||
StepId = Annotated[str, PATH_VALIDATORS[PrimitiveType.STEP.value]()]
|
||||
IdentityId = Annotated[str, PATH_VALIDATORS[PrimitiveType.IDENTITY.value]()]
|
||||
|
||||
|
||||
def raise_on_invalid_id(param_name: str, expected_prefix: PrimitiveType):
|
||||
"""
|
||||
Decorator that validates an ID parameter has the expected prefix format.
|
||||
Can be stacked multiple times on the same function to validate different IDs.
|
||||
|
||||
Args:
|
||||
param_name: The name of the function parameter to validate (e.g., "agent_id")
|
||||
expected_prefix: The expected primitive type (e.g., PrimitiveType.AGENT)
|
||||
|
||||
Example:
|
||||
@raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
|
||||
@raise_on_invalid_id(param_name="folder_id", expected_prefix=PrimitiveType.FOLDER)
|
||||
def my_function(agent_id: str, folder_id: str):
|
||||
pass
|
||||
"""
|
||||
|
||||
def decorator(function):
|
||||
@wraps(function)
|
||||
def wrapper(*args, **kwargs):
|
||||
sig = inspect.signature(function)
|
||||
bound_args = sig.bind(*args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
|
||||
if param_name in bound_args.arguments:
|
||||
arg_value = bound_args.arguments[param_name]
|
||||
|
||||
if arg_value is not None:
|
||||
prefix = PRIMITIVE_ID_PREFIXES[expected_prefix]
|
||||
if PRIMITIVE_ID_PATTERNS[prefix].match(arg_value) is None:
|
||||
raise LettaInvalidArgumentError(
|
||||
message=f"Invalid {expected_prefix.value} ID format: {arg_value}. Expected format: '{prefix}-<uuid4>'",
|
||||
argument_name=param_name,
|
||||
)
|
||||
|
||||
return function(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -286,7 +286,7 @@ async def test_turbopuffer_metadata_attributes(default_user, enable_turbopuffer)
|
||||
pytest.skip("No Turbopuffer API key available")
|
||||
|
||||
client = TurbopufferClient()
|
||||
archive_id = f"test-archive-{datetime.now().timestamp()}"
|
||||
archive_id = f"archive-{uuid.uuid4()}"
|
||||
|
||||
try:
|
||||
# Insert passages with various metadata
|
||||
@@ -391,7 +391,7 @@ async def test_hybrid_search_with_real_tpuf(default_user, enable_turbopuffer):
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
|
||||
client = TurbopufferClient()
|
||||
archive_id = f"test-hybrid-{datetime.now().timestamp()}"
|
||||
archive_id = f"archive-{uuid.uuid4()}"
|
||||
org_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
@@ -497,7 +497,7 @@ async def test_tag_filtering_with_real_tpuf(default_user, enable_turbopuffer):
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
|
||||
client = TurbopufferClient()
|
||||
archive_id = f"test-tags-{datetime.now().timestamp()}"
|
||||
archive_id = f"archive-{uuid.uuid4()}"
|
||||
org_id = str(uuid.uuid4())
|
||||
|
||||
try:
|
||||
@@ -628,7 +628,7 @@ async def test_temporal_filtering_with_real_tpuf(default_user, enable_turbopuffe
|
||||
client = TurbopufferClient()
|
||||
|
||||
# Create a unique archive ID for this test
|
||||
archive_id = f"test-temporal-{uuid.uuid4()}"
|
||||
archive_id = f"archive-{uuid.uuid4()}"
|
||||
|
||||
try:
|
||||
# Create passages with different timestamps
|
||||
|
||||
@@ -44,7 +44,7 @@ from letta.constants import (
|
||||
MULTI_AGENT_TOOLS,
|
||||
)
|
||||
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
|
||||
from letta.errors import LettaAgentNotFoundError
|
||||
from letta.errors import LettaAgentNotFoundError, LettaInvalidArgumentError
|
||||
from letta.functions.functions import derive_openai_json_schema, parse_source_code
|
||||
from letta.functions.mcp_client.types import MCPTool
|
||||
from letta.helpers import ToolRulesSolver
|
||||
@@ -233,7 +233,7 @@ async def test_update_job_auto_complete(server: SyncServer, default_user):
|
||||
async def test_get_job_not_found(server: SyncServer, default_user):
|
||||
"""Test fetching a non-existent job."""
|
||||
non_existent_job_id = "nonexistent-id"
|
||||
with pytest.raises(NoResultFound):
|
||||
with pytest.raises(LettaInvalidArgumentError):
|
||||
await server.job_manager.get_job_by_id_async(non_existent_job_id, actor=default_user)
|
||||
|
||||
|
||||
@@ -241,7 +241,7 @@ async def test_get_job_not_found(server: SyncServer, default_user):
|
||||
async def test_delete_job_not_found(server: SyncServer, default_user):
|
||||
"""Test deleting a non-existent job."""
|
||||
non_existent_job_id = "nonexistent-id"
|
||||
with pytest.raises(NoResultFound):
|
||||
with pytest.raises(LettaInvalidArgumentError):
|
||||
await server.job_manager.delete_job_by_id_async(non_existent_job_id, actor=default_user)
|
||||
|
||||
|
||||
@@ -412,9 +412,6 @@ async def test_e2e_job_callback(monkeypatch, server: SyncServer, default_user):
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_ttft(server: SyncServer, default_user):
|
||||
"""Test recording time to first token for a job."""
|
||||
@@ -478,9 +475,11 @@ async def test_record_timing_metrics_together(server: SyncServer, default_user):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_timing_invalid_job(server: SyncServer, default_user):
|
||||
"""Test recording timing metrics for non-existent job fails gracefully."""
|
||||
# Try to record TTFT for non-existent job - should not raise exception but log warning
|
||||
await server.job_manager.record_ttft("nonexistent_job_id", 1_000_000_000, default_user)
|
||||
"""Test recording timing metrics for non-existent job raises LettaInvalidArgumentError."""
|
||||
# Try to record TTFT for non-existent job - should raise LettaInvalidArgumentError
|
||||
with pytest.raises(LettaInvalidArgumentError):
|
||||
await server.job_manager.record_ttft("nonexistent_job_id", 1_000_000_000, default_user)
|
||||
|
||||
# Try to record response duration for non-existent job - should not raise exception but log warning
|
||||
await server.job_manager.record_response_duration("nonexistent_job_id", 2_000_000_000, default_user)
|
||||
# Try to record response duration for non-existent job - should raise LettaInvalidArgumentError
|
||||
with pytest.raises(LettaInvalidArgumentError):
|
||||
await server.job_manager.record_response_duration("nonexistent_job_id", 2_000_000_000, default_user)
|
||||
|
||||
@@ -44,7 +44,7 @@ from letta.constants import (
|
||||
MULTI_AGENT_TOOLS,
|
||||
)
|
||||
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
|
||||
from letta.errors import LettaAgentNotFoundError
|
||||
from letta.errors import LettaAgentNotFoundError, LettaInvalidArgumentError
|
||||
from letta.functions.functions import derive_openai_json_schema, parse_source_code
|
||||
from letta.functions.mcp_client.types import MCPTool
|
||||
from letta.helpers import ToolRulesSolver
|
||||
@@ -241,7 +241,7 @@ async def test_update_run_auto_complete(server: SyncServer, default_user, sarah_
|
||||
async def test_get_run_not_found(server: SyncServer, default_user):
|
||||
"""Test fetching a non-existent run."""
|
||||
non_existent_run_id = "nonexistent-id"
|
||||
with pytest.raises(NoResultFound):
|
||||
with pytest.raises(LettaInvalidArgumentError):
|
||||
await server.run_manager.get_run_by_id(non_existent_run_id, actor=default_user)
|
||||
|
||||
|
||||
@@ -249,7 +249,7 @@ async def test_get_run_not_found(server: SyncServer, default_user):
|
||||
async def test_delete_run_not_found(server: SyncServer, default_user):
|
||||
"""Test deleting a non-existent run."""
|
||||
non_existent_run_id = "nonexistent-id"
|
||||
with pytest.raises(NoResultFound):
|
||||
with pytest.raises(LettaInvalidArgumentError):
|
||||
await server.run_manager.delete_run(non_existent_run_id, actor=default_user)
|
||||
|
||||
|
||||
@@ -1268,7 +1268,7 @@ async def test_run_usage_stats_get_nonexistent_run(server: SyncServer, default_u
|
||||
"""Test getting usage statistics for a nonexistent run."""
|
||||
run_manager = server.run_manager
|
||||
|
||||
with pytest.raises(NoResultFound):
|
||||
with pytest.raises(LettaInvalidArgumentError):
|
||||
await run_manager.get_run_usage(run_id="nonexistent_run", actor=default_user)
|
||||
|
||||
|
||||
@@ -1307,7 +1307,7 @@ async def test_get_run_request_config_none(server: SyncServer, sarah_agent, defa
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_run_request_config_nonexistent_run(server: SyncServer, default_user):
|
||||
"""Test getting request config for a nonexistent run."""
|
||||
with pytest.raises(NoResultFound):
|
||||
with pytest.raises(LettaInvalidArgumentError):
|
||||
await server.run_manager.get_run_request_config("nonexistent_run", actor=default_user)
|
||||
|
||||
|
||||
@@ -1453,7 +1453,7 @@ async def test_run_metrics_num_steps_tracking(server: SyncServer, sarah_agent, d
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_metrics_not_found(server: SyncServer, default_user):
|
||||
"""Test getting metrics for non-existent run."""
|
||||
with pytest.raises(NoResultFound):
|
||||
with pytest.raises(LettaInvalidArgumentError):
|
||||
await server.run_manager.get_run_metrics_async(run_id="nonexistent_run", actor=default_user)
|
||||
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ from letta.constants import (
|
||||
MULTI_AGENT_TOOLS,
|
||||
)
|
||||
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
|
||||
from letta.errors import LettaAgentNotFoundError
|
||||
from letta.errors import LettaAgentNotFoundError, LettaInvalidArgumentError
|
||||
from letta.functions.functions import derive_openai_json_schema, parse_source_code
|
||||
from letta.functions.mcp_client.types import MCPTool
|
||||
from letta.helpers import ToolRulesSolver
|
||||
@@ -174,21 +174,21 @@ async def test_detach_source(server: SyncServer, sarah_agent, default_source, de
|
||||
async def test_attach_source_nonexistent_agent(server: SyncServer, default_source, default_user):
|
||||
"""Test attaching a source to a nonexistent agent."""
|
||||
with pytest.raises(NoResultFound):
|
||||
await server.agent_manager.attach_source_async(agent_id="nonexistent-agent-id", source_id=default_source.id, actor=default_user)
|
||||
await server.agent_manager.attach_source_async(agent_id=f"agent-{uuid.uuid4()}", source_id=default_source.id, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_source_nonexistent_source(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test attaching a nonexistent source to an agent."""
|
||||
with pytest.raises(NoResultFound):
|
||||
await server.agent_manager.attach_source_async(agent_id=sarah_agent.id, source_id="nonexistent-source-id", actor=default_user)
|
||||
await server.agent_manager.attach_source_async(agent_id=sarah_agent.id, source_id=f"source-{uuid.uuid4()}", actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_detach_source_nonexistent_agent(server: SyncServer, default_source, default_user):
|
||||
"""Test detaching a source from a nonexistent agent."""
|
||||
with pytest.raises(LettaAgentNotFoundError):
|
||||
await server.agent_manager.detach_source_async(agent_id="nonexistent-agent-id", source_id=default_source.id, actor=default_user)
|
||||
await server.agent_manager.detach_source_async(agent_id=f"agent-{uuid.uuid4()}", source_id=default_source.id, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -234,7 +234,7 @@ async def test_list_attached_agents(server: SyncServer, sarah_agent, charles_age
|
||||
|
||||
async def test_list_attached_agents_nonexistent_source(server: SyncServer, default_user):
|
||||
"""Test listing agents for a nonexistent source."""
|
||||
with pytest.raises(NoResultFound):
|
||||
with pytest.raises(LettaInvalidArgumentError):
|
||||
await server.source_manager.list_attached_agents(source_id="nonexistent-source-id", actor=default_user)
|
||||
|
||||
|
||||
|
||||
@@ -220,7 +220,7 @@ async def test_bulk_detach_tools_idempotent(server: SyncServer, sarah_agent, pri
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_detach_tools_nonexistent_agent(server: SyncServer, print_tool, other_tool, default_user):
|
||||
"""Test bulk detaching tools from a nonexistent agent."""
|
||||
nonexistent_agent_id = "nonexistent-agent-id"
|
||||
nonexistent_agent_id = f"agent-{uuid.uuid4()}"
|
||||
tool_ids = [print_tool.id, other_tool.id]
|
||||
|
||||
with pytest.raises(LettaAgentNotFoundError):
|
||||
@@ -230,19 +230,19 @@ async def test_bulk_detach_tools_nonexistent_agent(server: SyncServer, print_too
|
||||
async def test_attach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user):
|
||||
"""Test attaching a tool to a nonexistent agent."""
|
||||
with pytest.raises(LettaAgentNotFoundError):
|
||||
await server.agent_manager.attach_tool_async(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user)
|
||||
await server.agent_manager.attach_tool_async(agent_id=f"agent-{uuid.uuid4()}", tool_id=print_tool.id, actor=default_user)
|
||||
|
||||
|
||||
async def test_attach_tool_nonexistent_tool(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test attaching a nonexistent tool to an agent."""
|
||||
with pytest.raises(NoResultFound):
|
||||
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id="nonexistent-tool-id", actor=default_user)
|
||||
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=f"tool-{uuid.uuid4()}", actor=default_user)
|
||||
|
||||
|
||||
async def test_detach_tool_nonexistent_agent(server: SyncServer, print_tool, default_user):
|
||||
"""Test detaching a tool from a nonexistent agent."""
|
||||
with pytest.raises(LettaAgentNotFoundError):
|
||||
await server.agent_manager.detach_tool_async(agent_id="nonexistent-agent-id", tool_id=print_tool.id, actor=default_user)
|
||||
await server.agent_manager.detach_tool_async(agent_id=f"agent-{uuid.uuid4()}", tool_id=print_tool.id, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -330,7 +330,7 @@ async def test_bulk_attach_tools_nonexistent_tool(server: SyncServer, sarah_agen
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_attach_tools_nonexistent_agent(server: SyncServer, print_tool, other_tool, default_user):
|
||||
"""Test bulk attaching tools to a nonexistent agent."""
|
||||
nonexistent_agent_id = "nonexistent-agent-id"
|
||||
nonexistent_agent_id = f"agent-{uuid.uuid4()}"
|
||||
tool_ids = [print_tool.id, other_tool.id]
|
||||
|
||||
with pytest.raises(LettaAgentNotFoundError):
|
||||
@@ -2187,7 +2187,8 @@ async def test_list_tools_with_corrupted_tool(server: SyncServer, default_user,
|
||||
from letta.orm.tool import Tool as ToolModel
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# Create a tool with no json_schema (corrupted state)
|
||||
# Create a tool with corrupted ID format (bypassing validation)
|
||||
# This simulates a tool that somehow got corrupted in the database
|
||||
corrupted_tool = ToolModel(
|
||||
id=f"tool-corrupted-{uuid.uuid4()}",
|
||||
name="corrupted_tool",
|
||||
|
||||
Reference in New Issue
Block a user