feat: Allow per-agent tool execution environment variables (#509)
This commit is contained in:
@@ -0,0 +1,51 @@
|
||||
"""Add per agent environment variables
|
||||
|
||||
Revision ID: 400501b04bf0
|
||||
Revises: e78b4e82db30
|
||||
Create Date: 2025-01-04 20:45:28.024690
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "400501b04bf0"
|
||||
down_revision: Union[str, None] = "e78b4e82db30"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"agent_environment_variables",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("key", sa.String(), nullable=False),
|
||||
sa.Column("value", sa.String(), nullable=False),
|
||||
sa.Column("description", sa.String(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
|
||||
sa.Column("_created_by_id", sa.String(), nullable=True),
|
||||
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
|
||||
sa.Column("organization_id", sa.String(), nullable=False),
|
||||
sa.Column("agent_id", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"],
|
||||
["organizations.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("key", "agent_id", name="uix_key_agent"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("agent_environment_variables")
|
||||
# ### end Alembic commands ###
|
||||
@@ -4,9 +4,10 @@ import uuid
|
||||
|
||||
from letta import create_client
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ChatMemory
|
||||
from letta.schemas.sandbox_config import SandboxEnvironmentVariableCreate, SandboxType
|
||||
from letta.schemas.sandbox_config import SandboxType
|
||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
from letta.settings import tool_settings
|
||||
|
||||
|
||||
@@ -15,6 +15,11 @@ from letta.schemas.embedding_config import EmbeddingConfig
|
||||
|
||||
# new schemas
|
||||
from letta.schemas.enums import JobStatus, MessageRole
|
||||
from letta.schemas.environment_variables import (
|
||||
SandboxEnvironmentVariable,
|
||||
SandboxEnvironmentVariableCreate,
|
||||
SandboxEnvironmentVariableUpdate,
|
||||
)
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.job import Job
|
||||
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
|
||||
@@ -25,16 +30,7 @@ from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
||||
from letta.schemas.openai.chat_completions import ToolCall
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.sandbox_config import (
|
||||
E2BSandboxConfig,
|
||||
LocalSandboxConfig,
|
||||
SandboxConfig,
|
||||
SandboxConfigCreate,
|
||||
SandboxConfigUpdate,
|
||||
SandboxEnvironmentVariable,
|
||||
SandboxEnvironmentVariableCreate,
|
||||
SandboxEnvironmentVariableUpdate,
|
||||
)
|
||||
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfig, SandboxConfigCreate, SandboxConfigUpdate
|
||||
from letta.schemas.source import Source, SourceCreate, SourceUpdate
|
||||
from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
|
||||
from letta.schemas.tool_rule import BaseToolRule
|
||||
|
||||
@@ -31,7 +31,7 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
||||
|
||||
# agent generates its own id
|
||||
# TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase
|
||||
# TODO: Move this in this PR? at the very end?
|
||||
# TODO: Some still rely on the Pydantic object to do this
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"agent-{uuid.uuid4()}")
|
||||
|
||||
# Descriptor fields
|
||||
@@ -61,6 +61,13 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="agents")
|
||||
tool_exec_environment_variables: Mapped[List["AgentEnvironmentVariable"]] = relationship(
|
||||
"AgentEnvironmentVariable",
|
||||
back_populates="agent",
|
||||
cascade="all, delete-orphan",
|
||||
lazy="selectin",
|
||||
doc="Environment variables associated with this agent.",
|
||||
)
|
||||
tools: Mapped[List["Tool"]] = relationship("Tool", secondary="tools_agents", lazy="selectin", passive_deletes=True)
|
||||
sources: Mapped[List["Source"]] = relationship("Source", secondary="sources_agents", lazy="selectin")
|
||||
core_memory: Mapped[List["Block"]] = relationship("Block", secondary="blocks_agents", lazy="selectin")
|
||||
@@ -119,5 +126,6 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
||||
"last_updated_by_id": self.last_updated_by_id,
|
||||
"created_at": self.created_at,
|
||||
"updated_at": self.updated_at,
|
||||
"tool_exec_environment_variables": self.tool_exec_environment_variables,
|
||||
}
|
||||
return self.__pydantic_model__(**state)
|
||||
|
||||
@@ -9,6 +9,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from letta.orm.agent import Agent
|
||||
from letta.orm.file import FileMetadata
|
||||
from letta.orm.sandbox_config import AgentEnvironmentVariable
|
||||
from letta.orm.tool import Tool
|
||||
from letta.orm.user import User
|
||||
|
||||
@@ -33,6 +34,9 @@ class Organization(SqlalchemyBase):
|
||||
sandbox_environment_variables: Mapped[List["SandboxEnvironmentVariable"]] = relationship(
|
||||
"SandboxEnvironmentVariable", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
agent_environment_variables: Mapped[List["AgentEnvironmentVariable"]] = relationship(
|
||||
"AgentEnvironmentVariable", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# relationships
|
||||
agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import JSON
|
||||
@@ -5,13 +6,14 @@ from sqlalchemy import Enum as SqlEnum
|
||||
from sqlalchemy import String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin, SandboxConfigMixin
|
||||
from letta.orm.mixins import AgentMixin, OrganizationMixin, SandboxConfigMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariable as PydanticSandboxEnvironmentVariable
|
||||
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
|
||||
from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticSandboxEnvironmentVariable
|
||||
from letta.schemas.sandbox_config import SandboxType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.agent import Agent
|
||||
from letta.orm.organization import Organization
|
||||
|
||||
|
||||
@@ -52,3 +54,21 @@ class SandboxEnvironmentVariable(SqlalchemyBase, OrganizationMixin, SandboxConfi
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="sandbox_environment_variables")
|
||||
sandbox_config: Mapped["SandboxConfig"] = relationship("SandboxConfig", back_populates="sandbox_environment_variables")
|
||||
|
||||
|
||||
class AgentEnvironmentVariable(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
"""ORM model for environment variables associated with agents."""
|
||||
|
||||
__tablename__ = "agent_environment_variables"
|
||||
# We cannot have duplicate key names for the same agent, the env var would get overwritten
|
||||
__table_args__ = (UniqueConstraint("key", "agent_id", name="uix_key_agent"),)
|
||||
|
||||
# agent_env_var generates its own id
|
||||
# TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"agent-env-{uuid.uuid4()}")
|
||||
key: Mapped[str] = mapped_column(String, nullable=False, doc="The name of the environment variable.")
|
||||
value: Mapped[str] = mapped_column(String, nullable=False, doc="The value of the environment variable.")
|
||||
description: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="An optional description of the environment variable.")
|
||||
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="agent_environment_variables")
|
||||
agent: Mapped[List["Agent"]] = relationship("Agent", back_populates="tool_exec_environment_variables")
|
||||
|
||||
@@ -6,6 +6,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.environment_variables import AgentEnvironmentVariable
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import Memory
|
||||
@@ -78,6 +79,9 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
|
||||
tools: List[Tool] = Field(..., description="The tools used by the agent.")
|
||||
sources: List[Source] = Field(..., description="The sources used by the agent.")
|
||||
tags: List[str] = Field(..., description="The tags associated with the agent.")
|
||||
tool_exec_environment_variables: List[AgentEnvironmentVariable] = Field(
|
||||
..., description="The environment variables for tool execution specific to this agent."
|
||||
)
|
||||
|
||||
|
||||
class CreateAgent(BaseModel, validate_assignment=True): #
|
||||
@@ -120,6 +124,9 @@ class CreateAgent(BaseModel, validate_assignment=True): #
|
||||
embedding_chunk_size: Optional[int] = Field(DEFAULT_EMBEDDING_CHUNK_SIZE, description="The embedding chunk size used by the agent.")
|
||||
from_template: Optional[str] = Field(None, description="The template id used to configure the agent")
|
||||
project_id: Optional[str] = Field(None, description="The project id that the agent will be associated with.")
|
||||
tool_exec_environment_variables: Optional[Dict[str, str]] = Field(
|
||||
None, description="The environment variables for tool execution specific to this agent."
|
||||
)
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
@@ -184,6 +191,9 @@ class UpdateAgent(BaseModel):
|
||||
message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
|
||||
description: Optional[str] = Field(None, description="The description of the agent.")
|
||||
metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_")
|
||||
tool_exec_environment_variables: Optional[Dict[str, str]] = Field(
|
||||
None, description="The environment variables for tool execution specific to this agent."
|
||||
)
|
||||
|
||||
class Config:
|
||||
extra = "ignore" # Ignores extra fields
|
||||
|
||||
62
letta/schemas/environment_variables.py
Normal file
62
letta/schemas/environment_variables.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.letta_base import LettaBase, OrmMetadataBase
|
||||
|
||||
|
||||
# Base Environment Variable
|
||||
class EnvironmentVariableBase(OrmMetadataBase):
|
||||
id: str = Field(..., description="The unique identifier for the environment variable.")
|
||||
key: str = Field(..., description="The name of the environment variable.")
|
||||
value: str = Field(..., description="The value of the environment variable.")
|
||||
description: Optional[str] = Field(None, description="An optional description of the environment variable.")
|
||||
organization_id: Optional[str] = Field(None, description="The ID of the organization this environment variable belongs to.")
|
||||
|
||||
|
||||
class EnvironmentVariableCreateBase(LettaBase):
|
||||
key: str = Field(..., description="The name of the environment variable.")
|
||||
value: str = Field(..., description="The value of the environment variable.")
|
||||
description: Optional[str] = Field(None, description="An optional description of the environment variable.")
|
||||
|
||||
|
||||
class EnvironmentVariableUpdateBase(LettaBase):
|
||||
key: Optional[str] = Field(None, description="The name of the environment variable.")
|
||||
value: Optional[str] = Field(None, description="The value of the environment variable.")
|
||||
description: Optional[str] = Field(None, description="An optional description of the environment variable.")
|
||||
|
||||
|
||||
# Sandbox-Specific Environment Variable
|
||||
class SandboxEnvironmentVariableBase(EnvironmentVariableBase):
|
||||
__id_prefix__ = "sandbox-env"
|
||||
sandbox_config_id: str = Field(..., description="The ID of the sandbox config this environment variable belongs to.")
|
||||
|
||||
|
||||
class SandboxEnvironmentVariable(SandboxEnvironmentVariableBase):
|
||||
id: str = SandboxEnvironmentVariableBase.generate_id_field()
|
||||
|
||||
|
||||
class SandboxEnvironmentVariableCreate(EnvironmentVariableCreateBase):
|
||||
pass
|
||||
|
||||
|
||||
class SandboxEnvironmentVariableUpdate(EnvironmentVariableUpdateBase):
|
||||
pass
|
||||
|
||||
|
||||
# Agent-Specific Environment Variable
|
||||
class AgentEnvironmentVariableBase(EnvironmentVariableBase):
|
||||
__id_prefix__ = "agent-env"
|
||||
agent_id: str = Field(..., description="The ID of the agent this environment variable belongs to.")
|
||||
|
||||
|
||||
class AgentEnvironmentVariable(AgentEnvironmentVariableBase):
|
||||
id: str = AgentEnvironmentVariableBase.generate_id_field()
|
||||
|
||||
|
||||
class AgentEnvironmentVariableCreate(EnvironmentVariableCreateBase):
|
||||
pass
|
||||
|
||||
|
||||
class AgentEnvironmentVariableUpdate(EnvironmentVariableUpdateBase):
|
||||
pass
|
||||
@@ -102,31 +102,3 @@ class SandboxConfigUpdate(LettaBase):
|
||||
"""Pydantic model for updating SandboxConfig fields."""
|
||||
|
||||
config: Union[LocalSandboxConfig, E2BSandboxConfig] = Field(None, description="The JSON configuration data for the sandbox.")
|
||||
|
||||
|
||||
# Environment Variable
|
||||
class SandboxEnvironmentVariableBase(OrmMetadataBase):
|
||||
__id_prefix__ = "sandbox-env"
|
||||
|
||||
|
||||
class SandboxEnvironmentVariable(SandboxEnvironmentVariableBase):
|
||||
id: str = SandboxEnvironmentVariableBase.generate_id_field()
|
||||
key: str = Field(..., description="The name of the environment variable.")
|
||||
value: str = Field(..., description="The value of the environment variable.")
|
||||
description: Optional[str] = Field(None, description="An optional description of the environment variable.")
|
||||
sandbox_config_id: str = Field(..., description="The ID of the sandbox config this environment variable belongs to.")
|
||||
organization_id: Optional[str] = Field(None, description="The ID of the organization this environment variable belongs to.")
|
||||
|
||||
|
||||
class SandboxEnvironmentVariableCreate(LettaBase):
|
||||
key: str = Field(..., description="The name of the environment variable.")
|
||||
value: str = Field(..., description="The value of the environment variable.")
|
||||
description: Optional[str] = Field(None, description="An optional description of the environment variable.")
|
||||
|
||||
|
||||
class SandboxEnvironmentVariableUpdate(LettaBase):
|
||||
"""Pydantic model for updating SandboxEnvironmentVariable fields."""
|
||||
|
||||
key: Optional[str] = Field(None, description="The name of the environment variable.")
|
||||
value: Optional[str] = Field(None, description="The value of the environment variable.")
|
||||
description: Optional[str] = Field(None, description="An optional description of the environment variable.")
|
||||
|
||||
@@ -2,10 +2,10 @@ from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariable as PydanticEnvVar
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
|
||||
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
|
||||
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate
|
||||
from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticEnvVar
|
||||
from letta.schemas.sandbox_config import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate, SandboxType
|
||||
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate, SandboxType
|
||||
from letta.server.rest_api.utils import get_letta_server, get_user_id
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ from letta.schemas.embedding_config import EmbeddingConfig
|
||||
|
||||
# openai schemas
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
|
||||
from letta.schemas.job import Job, JobUpdate
|
||||
from letta.schemas.letta_message import LettaMessage, ToolReturnMessage
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
@@ -53,7 +54,7 @@ from letta.schemas.memory import ArchivalMemorySummary, ContextWindowOverview, M
|
||||
from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.sandbox_config import SandboxEnvironmentVariableCreate, SandboxType
|
||||
from letta.schemas.sandbox_config import SandboxType
|
||||
from letta.schemas.source import Source
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
|
||||
@@ -14,6 +14,7 @@ from letta.orm import Source as SourceModel
|
||||
from letta.orm import SourcePassage, SourcesAgents
|
||||
from letta.orm import Tool as ToolModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.sandbox_config import AgentEnvironmentVariable as AgentEnvironmentVariableModel
|
||||
from letta.orm.sqlite_functions import adapt_array
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent
|
||||
@@ -116,6 +117,14 @@ class AgentManager:
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# If there are provided environment variables, add them in
|
||||
if agent_create.tool_exec_environment_variables:
|
||||
agent_state = self._set_environment_variables(
|
||||
agent_id=agent_state.id,
|
||||
env_vars=agent_create.tool_exec_environment_variables,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# TODO: See if we can merge this into the above SQL create call for performance reasons
|
||||
# Generate a sequence of initial messages to put in the buffer
|
||||
init_messages = initialize_message_sequence(
|
||||
@@ -192,6 +201,14 @@ class AgentManager:
|
||||
def update_agent(self, agent_id: str, agent_update: UpdateAgent, actor: PydanticUser) -> PydanticAgentState:
|
||||
agent_state = self._update_agent(agent_id=agent_id, agent_update=agent_update, actor=actor)
|
||||
|
||||
# If there are provided environment variables, add them in
|
||||
if agent_update.tool_exec_environment_variables:
|
||||
agent_state = self._set_environment_variables(
|
||||
agent_id=agent_state.id,
|
||||
env_vars=agent_update.tool_exec_environment_variables,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# Rebuild the system prompt if it's different
|
||||
if agent_update.system and agent_update.system != agent_state.system:
|
||||
agent_state = self.rebuild_system_prompt(agent_id=agent_state.id, actor=actor, force=True, update_timestamp=False)
|
||||
@@ -296,6 +313,43 @@ class AgentManager:
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
agent.hard_delete(session)
|
||||
|
||||
# ======================================================================================================================
|
||||
# Per Agent Environment Variable Management
|
||||
# ======================================================================================================================
|
||||
@enforce_types
|
||||
def _set_environment_variables(
|
||||
self,
|
||||
agent_id: str,
|
||||
env_vars: Dict[str, str],
|
||||
actor: PydanticUser,
|
||||
) -> PydanticAgentState:
|
||||
"""
|
||||
Adds or replaces the environment variables for the specified agent.
|
||||
|
||||
Args:
|
||||
agent_id: The agent id.
|
||||
env_vars: A dictionary of environment variable key-value pairs.
|
||||
actor: The user performing the action.
|
||||
|
||||
Returns:
|
||||
PydanticAgentState: The updated agent as a Pydantic model.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
# Retrieve the agent
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# Replace the environment variables
|
||||
agent.tool_exec_environment_variables = [
|
||||
AgentEnvironmentVariableModel(key=key, value=value, agent_id=agent_id, organization_id=actor.organization_id)
|
||||
for key, value in env_vars.items()
|
||||
]
|
||||
|
||||
# Update the agent in the database
|
||||
agent.update(session, actor=actor)
|
||||
|
||||
# Return the updated agent state
|
||||
return agent.to_pydantic()
|
||||
|
||||
# ======================================================================================================================
|
||||
# In Context Messages Management
|
||||
# ======================================================================================================================
|
||||
|
||||
@@ -5,11 +5,11 @@ from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.sandbox_config import SandboxConfig as SandboxConfigModel
|
||||
from letta.orm.sandbox_config import SandboxEnvironmentVariable as SandboxEnvVarModel
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariable as PydanticEnvVar
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
|
||||
from letta.schemas.sandbox_config import LocalSandboxConfig
|
||||
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
|
||||
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate
|
||||
from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticEnvVar
|
||||
from letta.schemas.sandbox_config import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate, SandboxType
|
||||
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate, SandboxType
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.utils import enforce_types, printd
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from letta.functions.functions import parse_source_code
|
||||
from letta.functions.schema_generator import generate_schema
|
||||
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
|
||||
|
||||
def cleanup(client: Union[LocalClient, RESTClient], agent_uuid: str):
|
||||
@@ -27,12 +28,19 @@ def create_tool_from_func(func: callable):
|
||||
)
|
||||
|
||||
|
||||
def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, UpdateAgent]):
|
||||
def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, UpdateAgent], actor: PydanticUser):
|
||||
# Assert scalar fields
|
||||
assert agent.system == request.system, f"System prompt mismatch: {agent.system} != {request.system}"
|
||||
assert agent.description == request.description, f"Description mismatch: {agent.description} != {request.description}"
|
||||
assert agent.metadata_ == request.metadata_, f"Metadata mismatch: {agent.metadata_} != {request.metadata_}"
|
||||
|
||||
# Assert agent env vars
|
||||
if hasattr(request, "tool_exec_environment_variables"):
|
||||
for agent_env_var in agent.tool_exec_environment_variables:
|
||||
assert agent_env_var.key in request.tool_exec_environment_variables
|
||||
assert request.tool_exec_environment_variables[agent_env_var.key] == agent_env_var.value
|
||||
assert agent_env_var.organization_id == actor.organization_id
|
||||
|
||||
# Assert agent type
|
||||
if hasattr(request, "agent_type"):
|
||||
assert agent.agent_type == request.agent_type, f"Agent type mismatch: {agent.agent_type} != {request.agent_type}"
|
||||
|
||||
@@ -9,20 +9,14 @@ from sqlalchemy import delete
|
||||
|
||||
from letta import create_client
|
||||
from letta.functions.function_sets.base import core_memory_append, core_memory_replace
|
||||
from letta.orm import SandboxConfig, SandboxEnvironmentVariable
|
||||
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ChatMemory
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.sandbox_config import (
|
||||
E2BSandboxConfig,
|
||||
LocalSandboxConfig,
|
||||
SandboxConfigCreate,
|
||||
SandboxConfigUpdate,
|
||||
SandboxEnvironmentVariableCreate,
|
||||
SandboxType,
|
||||
)
|
||||
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate, SandboxType
|
||||
from letta.schemas.tool import Tool, ToolCreate
|
||||
from letta.schemas.user import User
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
|
||||
@@ -35,6 +35,7 @@ from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.block import BlockUpdate, CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import JobStatus, MessageRole
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
|
||||
from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||||
from letta.schemas.job import Job as PydanticJob
|
||||
from letta.schemas.job import JobUpdate
|
||||
@@ -43,15 +44,7 @@ from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import MessageCreate, MessageUpdate
|
||||
from letta.schemas.organization import Organization as PydanticOrganization
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.sandbox_config import (
|
||||
E2BSandboxConfig,
|
||||
LocalSandboxConfig,
|
||||
SandboxConfigCreate,
|
||||
SandboxConfigUpdate,
|
||||
SandboxEnvironmentVariableCreate,
|
||||
SandboxEnvironmentVariableUpdate,
|
||||
SandboxType,
|
||||
)
|
||||
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate, SandboxType
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
from letta.schemas.source import SourceUpdate
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
@@ -413,6 +406,7 @@ def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_too
|
||||
metadata_={"test_key": "test_value"},
|
||||
tool_rules=[InitToolRule(tool_name=print_tool.name)],
|
||||
initial_message_sequence=[MessageCreate(role=MessageRole.user, text="hello world")],
|
||||
tool_exec_environment_variables={"test_env_var_key_a": "test_env_var_value_a", "test_env_var_key_b": "test_env_var_value_b"},
|
||||
)
|
||||
created_agent = server.agent_manager.create_agent(
|
||||
create_agent_request,
|
||||
@@ -482,20 +476,20 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent):
|
||||
def test_create_get_list_agent(server: SyncServer, comprehensive_test_agent_fixture, default_user):
|
||||
# Test agent creation
|
||||
created_agent, create_agent_request = comprehensive_test_agent_fixture
|
||||
comprehensive_agent_checks(created_agent, create_agent_request)
|
||||
comprehensive_agent_checks(created_agent, create_agent_request, actor=default_user)
|
||||
|
||||
# Test get agent
|
||||
get_agent = server.agent_manager.get_agent_by_id(agent_id=created_agent.id, actor=default_user)
|
||||
comprehensive_agent_checks(get_agent, create_agent_request)
|
||||
comprehensive_agent_checks(get_agent, create_agent_request, actor=default_user)
|
||||
|
||||
# Test get agent name
|
||||
get_agent_name = server.agent_manager.get_agent_by_name(agent_name=created_agent.name, actor=default_user)
|
||||
comprehensive_agent_checks(get_agent_name, create_agent_request)
|
||||
comprehensive_agent_checks(get_agent_name, create_agent_request, actor=default_user)
|
||||
|
||||
# Test list agent
|
||||
list_agents = server.agent_manager.list_agents(actor=default_user)
|
||||
assert len(list_agents) == 1
|
||||
comprehensive_agent_checks(list_agents[0], create_agent_request)
|
||||
comprehensive_agent_checks(list_agents[0], create_agent_request, actor=default_user)
|
||||
|
||||
# Test deleting the agent
|
||||
server.agent_manager.delete_agent(get_agent.id, default_user)
|
||||
@@ -566,10 +560,11 @@ def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, othe
|
||||
embedding_config=EmbeddingConfig.default_config(model_name="letta"),
|
||||
message_ids=["10", "20"],
|
||||
metadata_={"train_key": "train_value"},
|
||||
tool_exec_environment_variables={"new_tool_exec_key": "new_tool_exec_value"},
|
||||
)
|
||||
|
||||
updated_agent = server.agent_manager.update_agent(agent.id, update_agent_request, actor=default_user)
|
||||
comprehensive_agent_checks(updated_agent, update_agent_request)
|
||||
comprehensive_agent_checks(updated_agent, update_agent_request, actor=default_user)
|
||||
assert updated_agent.message_ids == update_agent_request.message_ids
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user