diff --git a/alembic/versions/400501b04bf0_add_per_agent_environment_variables.py b/alembic/versions/400501b04bf0_add_per_agent_environment_variables.py new file mode 100644 index 00000000..584e1e4c --- /dev/null +++ b/alembic/versions/400501b04bf0_add_per_agent_environment_variables.py @@ -0,0 +1,51 @@ +"""Add per agent environment variables + +Revision ID: 400501b04bf0 +Revises: e78b4e82db30 +Create Date: 2025-01-04 20:45:28.024690 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "400501b04bf0" +down_revision: Union[str, None] = "e78b4e82db30" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "agent_environment_variables", + sa.Column("id", sa.String(), nullable=False), + sa.Column("key", sa.String(), nullable=False), + sa.Column("value", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.Column("agent_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("key", "agent_id", name="uix_key_agent"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("agent_environment_variables") + # ### end Alembic commands ### diff --git a/examples/composio_tool_usage.py b/examples/composio_tool_usage.py index c3c81895..fc6c3c12 100644 --- a/examples/composio_tool_usage.py +++ b/examples/composio_tool_usage.py @@ -4,9 +4,10 @@ import uuid from letta import create_client from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ChatMemory -from letta.schemas.sandbox_config import SandboxEnvironmentVariableCreate, SandboxType +from letta.schemas.sandbox_config import SandboxType from letta.services.sandbox_config_manager import SandboxConfigManager from letta.settings import tool_settings diff --git a/letta/client/client.py b/letta/client/client.py index 9931628c..ae75e9eb 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -15,6 +15,11 @@ from letta.schemas.embedding_config import EmbeddingConfig # new schemas from letta.schemas.enums import JobStatus, MessageRole +from letta.schemas.environment_variables import ( + SandboxEnvironmentVariable, + SandboxEnvironmentVariableCreate, + SandboxEnvironmentVariableUpdate, +) from letta.schemas.file import FileMetadata from letta.schemas.job import Job from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest @@ -25,16 +30,7 @@ from letta.schemas.message import Message, MessageCreate, MessageUpdate from letta.schemas.openai.chat_completions import ToolCall from letta.schemas.organization import Organization from letta.schemas.passage import Passage -from letta.schemas.sandbox_config import ( - E2BSandboxConfig, - LocalSandboxConfig, - SandboxConfig, - SandboxConfigCreate, - SandboxConfigUpdate, - SandboxEnvironmentVariable, - SandboxEnvironmentVariableCreate, - SandboxEnvironmentVariableUpdate, -) +from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfig, SandboxConfigCreate, SandboxConfigUpdate from letta.schemas.source import Source, SourceCreate, SourceUpdate from letta.schemas.tool import Tool, ToolCreate, ToolUpdate from letta.schemas.tool_rule import BaseToolRule diff --git a/letta/orm/agent.py b/letta/orm/agent.py index 353d4fe7..271527c6 100644 --- a/letta/orm/agent.py +++ b/letta/orm/agent.py @@ -31,7 +31,7 @@ class Agent(SqlalchemyBase, OrganizationMixin): # agent generates its own id # TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase - # TODO: Move this in this PR? at the very end? + # TODO: Some still rely on the Pydantic object to do this id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"agent-{uuid.uuid4()}") # Descriptor fields @@ -61,6 +61,13 @@ class Agent(SqlalchemyBase, OrganizationMixin): # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="agents") + tool_exec_environment_variables: Mapped[List["AgentEnvironmentVariable"]] = relationship( + "AgentEnvironmentVariable", + back_populates="agent", + cascade="all, delete-orphan", + lazy="selectin", + doc="Environment variables associated with this agent.", + ) tools: Mapped[List["Tool"]] = relationship("Tool", secondary="tools_agents", lazy="selectin", passive_deletes=True) sources: Mapped[List["Source"]] = relationship("Source", secondary="sources_agents", lazy="selectin") core_memory: Mapped[List["Block"]] = relationship("Block", secondary="blocks_agents", lazy="selectin") @@ -119,5 +126,6 @@ class Agent(SqlalchemyBase, OrganizationMixin): "last_updated_by_id": self.last_updated_by_id, "created_at": self.created_at, "updated_at": self.updated_at, + "tool_exec_environment_variables": self.tool_exec_environment_variables, } return self.__pydantic_model__(**state) diff --git a/letta/orm/organization.py b/letta/orm/organization.py index 335a15d0..486cfcc4 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from letta.orm.agent import Agent from letta.orm.file import FileMetadata + from letta.orm.sandbox_config import AgentEnvironmentVariable from letta.orm.tool import Tool from letta.orm.user import User @@ -33,6 +34,9 @@ class Organization(SqlalchemyBase): sandbox_environment_variables: Mapped[List["SandboxEnvironmentVariable"]] = relationship( "SandboxEnvironmentVariable", back_populates="organization", cascade="all, delete-orphan" ) + agent_environment_variables: Mapped[List["AgentEnvironmentVariable"]] = relationship( + "AgentEnvironmentVariable", back_populates="organization", cascade="all, delete-orphan" + ) # relationships agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan") diff --git a/letta/orm/sandbox_config.py b/letta/orm/sandbox_config.py index 9058657f..164814c5 100644 --- a/letta/orm/sandbox_config.py +++ b/letta/orm/sandbox_config.py @@ -1,3 +1,4 @@ +import uuid from typing import TYPE_CHECKING, Dict, List, Optional from sqlalchemy import JSON @@ -5,13 +6,14 @@ from sqlalchemy import Enum as SqlEnum from sqlalchemy import String, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column, relationship -from letta.orm.mixins import OrganizationMixin, SandboxConfigMixin +from letta.orm.mixins import AgentMixin, OrganizationMixin, SandboxConfigMixin from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.environment_variables import SandboxEnvironmentVariable as PydanticSandboxEnvironmentVariable from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig -from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticSandboxEnvironmentVariable from letta.schemas.sandbox_config import SandboxType if TYPE_CHECKING: + from letta.orm.agent import Agent from letta.orm.organization import Organization @@ -52,3 +54,21 @@ class SandboxEnvironmentVariable(SqlalchemyBase, OrganizationMixin, SandboxConfi # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="sandbox_environment_variables") sandbox_config: Mapped["SandboxConfig"] = relationship("SandboxConfig", back_populates="sandbox_environment_variables") + + +class AgentEnvironmentVariable(SqlalchemyBase, OrganizationMixin, AgentMixin): + """ORM model for environment variables associated with agents.""" + + __tablename__ = "agent_environment_variables" + # We cannot have duplicate key names for the same agent, the env var would get overwritten + __table_args__ = (UniqueConstraint("key", "agent_id", name="uix_key_agent"),) + + # agent_env_var generates its own id + # TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase + id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"agent-env-{uuid.uuid4()}") + key: Mapped[str] = mapped_column(String, nullable=False, doc="The name of the environment variable.") + value: Mapped[str] = mapped_column(String, nullable=False, doc="The value of the environment variable.") + description: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="An optional description of the environment variable.") + + organization: Mapped["Organization"] = relationship("Organization", back_populates="agent_environment_variables") + agent: Mapped[List["Agent"]] = relationship("Agent", back_populates="tool_exec_environment_variables") diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index 56b2168e..57268645 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -6,6 +6,7 @@ from pydantic import BaseModel, Field, field_validator from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE from letta.schemas.block import CreateBlock from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.environment_variables import AgentEnvironmentVariable from letta.schemas.letta_base import OrmMetadataBase from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import Memory @@ -78,6 +79,9 @@ class AgentState(OrmMetadataBase, validate_assignment=True): tools: List[Tool] = Field(..., description="The tools used by the agent.") sources: List[Source] = Field(..., description="The sources used by the agent.") tags: List[str] = Field(..., description="The tags associated with the agent.") + tool_exec_environment_variables: List[AgentEnvironmentVariable] = Field( + ..., description="The environment variables for tool execution specific to this agent." + ) class CreateAgent(BaseModel, validate_assignment=True): # @@ -120,6 +124,9 @@ class CreateAgent(BaseModel, validate_assignment=True): # embedding_chunk_size: Optional[int] = Field(DEFAULT_EMBEDDING_CHUNK_SIZE, description="The embedding chunk size used by the agent.") from_template: Optional[str] = Field(None, description="The template id used to configure the agent") project_id: Optional[str] = Field(None, description="The project id that the agent will be associated with.") + tool_exec_environment_variables: Optional[Dict[str, str]] = Field( + None, description="The environment variables for tool execution specific to this agent." + ) @field_validator("name") @classmethod @@ -184,6 +191,9 @@ class UpdateAgent(BaseModel): message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.") description: Optional[str] = Field(None, description="The description of the agent.") metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_") + tool_exec_environment_variables: Optional[Dict[str, str]] = Field( + None, description="The environment variables for tool execution specific to this agent." + ) class Config: extra = "ignore" # Ignores extra fields diff --git a/letta/schemas/environment_variables.py b/letta/schemas/environment_variables.py new file mode 100644 index 00000000..9f482c1c --- /dev/null +++ b/letta/schemas/environment_variables.py @@ -0,0 +1,62 @@ +from typing import Optional + +from pydantic import Field + +from letta.schemas.letta_base import LettaBase, OrmMetadataBase + + +# Base Environment Variable +class EnvironmentVariableBase(OrmMetadataBase): + id: str = Field(..., description="The unique identifier for the environment variable.") + key: str = Field(..., description="The name of the environment variable.") + value: str = Field(..., description="The value of the environment variable.") + description: Optional[str] = Field(None, description="An optional description of the environment variable.") + organization_id: Optional[str] = Field(None, description="The ID of the organization this environment variable belongs to.") + + +class EnvironmentVariableCreateBase(LettaBase): + key: str = Field(..., description="The name of the environment variable.") + value: str = Field(..., description="The value of the environment variable.") + description: Optional[str] = Field(None, description="An optional description of the environment variable.") + + +class EnvironmentVariableUpdateBase(LettaBase): + key: Optional[str] = Field(None, description="The name of the environment variable.") + value: Optional[str] = Field(None, description="The value of the environment variable.") + description: Optional[str] = Field(None, description="An optional description of the environment variable.") + + +# Sandbox-Specific Environment Variable +class SandboxEnvironmentVariableBase(EnvironmentVariableBase): + __id_prefix__ = "sandbox-env" + sandbox_config_id: str = Field(..., description="The ID of the sandbox config this environment variable belongs to.") + + +class SandboxEnvironmentVariable(SandboxEnvironmentVariableBase): + id: str = SandboxEnvironmentVariableBase.generate_id_field() + + +class SandboxEnvironmentVariableCreate(EnvironmentVariableCreateBase): + pass + + +class SandboxEnvironmentVariableUpdate(EnvironmentVariableUpdateBase): + pass + + +# Agent-Specific Environment Variable +class AgentEnvironmentVariableBase(EnvironmentVariableBase): + __id_prefix__ = "agent-env" + agent_id: str = Field(..., description="The ID of the agent this environment variable belongs to.") + + +class AgentEnvironmentVariable(AgentEnvironmentVariableBase): + id: str = AgentEnvironmentVariableBase.generate_id_field() + + +class AgentEnvironmentVariableCreate(EnvironmentVariableCreateBase): + pass + + +class AgentEnvironmentVariableUpdate(EnvironmentVariableUpdateBase): + pass diff --git a/letta/schemas/sandbox_config.py b/letta/schemas/sandbox_config.py index f86233fa..bc5698e9 100644 --- a/letta/schemas/sandbox_config.py +++ b/letta/schemas/sandbox_config.py @@ -102,31 +102,3 @@ class SandboxConfigUpdate(LettaBase): """Pydantic model for updating SandboxConfig fields.""" config: Union[LocalSandboxConfig, E2BSandboxConfig] = Field(None, description="The JSON configuration data for the sandbox.") - - -# Environment Variable -class SandboxEnvironmentVariableBase(OrmMetadataBase): - __id_prefix__ = "sandbox-env" - - -class SandboxEnvironmentVariable(SandboxEnvironmentVariableBase): - id: str = SandboxEnvironmentVariableBase.generate_id_field() - key: str = Field(..., description="The name of the environment variable.") - value: str = Field(..., description="The value of the environment variable.") - description: Optional[str] = Field(None, description="An optional description of the environment variable.") - sandbox_config_id: str = Field(..., description="The ID of the sandbox config this environment variable belongs to.") - organization_id: Optional[str] = Field(None, description="The ID of the organization this environment variable belongs to.") - - -class SandboxEnvironmentVariableCreate(LettaBase): - key: str = Field(..., description="The name of the environment variable.") - value: str = Field(..., description="The value of the environment variable.") - description: Optional[str] = Field(None, description="An optional description of the environment variable.") - - -class SandboxEnvironmentVariableUpdate(LettaBase): - """Pydantic model for updating SandboxEnvironmentVariable fields.""" - - key: Optional[str] = Field(None, description="The name of the environment variable.") - value: Optional[str] = Field(None, description="The value of the environment variable.") - description: Optional[str] = Field(None, description="An optional description of the environment variable.") diff --git a/letta/server/rest_api/routers/v1/sandbox_configs.py b/letta/server/rest_api/routers/v1/sandbox_configs.py index 436d9b8e..d5c16c04 100644 --- a/letta/server/rest_api/routers/v1/sandbox_configs.py +++ b/letta/server/rest_api/routers/v1/sandbox_configs.py @@ -2,10 +2,10 @@ from typing import List, Optional from fastapi import APIRouter, Depends, Query +from letta.schemas.environment_variables import SandboxEnvironmentVariable as PydanticEnvVar +from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig -from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate -from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticEnvVar -from letta.schemas.sandbox_config import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate, SandboxType +from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate, SandboxType from letta.server.rest_api.utils import get_letta_server, get_user_id from letta.server.server import SyncServer diff --git a/letta/server/server.py b/letta/server/server.py index a619463a..9d6dd859 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -46,6 +46,7 @@ from letta.schemas.embedding_config import EmbeddingConfig # openai schemas from letta.schemas.enums import JobStatus +from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate from letta.schemas.job import Job, JobUpdate from letta.schemas.letta_message import LettaMessage, ToolReturnMessage from letta.schemas.llm_config import LLMConfig @@ -53,7 +54,7 @@ from letta.schemas.memory import ArchivalMemorySummary, ContextWindowOverview, M from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate from letta.schemas.organization import Organization from letta.schemas.passage import Passage -from letta.schemas.sandbox_config import SandboxEnvironmentVariableCreate, SandboxType +from letta.schemas.sandbox_config import SandboxType from letta.schemas.source import Source from letta.schemas.tool import Tool from letta.schemas.usage import LettaUsageStatistics diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index a51a5fac..5c92f59e 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -14,6 +14,7 @@ from letta.orm import Source as SourceModel from letta.orm import SourcePassage, SourcesAgents from letta.orm import Tool as ToolModel from letta.orm.errors import NoResultFound +from letta.orm.sandbox_config import AgentEnvironmentVariable as AgentEnvironmentVariableModel from letta.orm.sqlite_functions import adapt_array from letta.schemas.agent import AgentState as PydanticAgentState from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent @@ -116,6 +117,14 @@ class AgentManager: actor=actor, ) + # If there are provided environment variables, add them in + if agent_create.tool_exec_environment_variables: + agent_state = self._set_environment_variables( + agent_id=agent_state.id, + env_vars=agent_create.tool_exec_environment_variables, + actor=actor, + ) + # TODO: See if we can merge this into the above SQL create call for performance reasons # Generate a sequence of initial messages to put in the buffer init_messages = initialize_message_sequence( @@ -192,6 +201,14 @@ class AgentManager: def update_agent(self, agent_id: str, agent_update: UpdateAgent, actor: PydanticUser) -> PydanticAgentState: agent_state = self._update_agent(agent_id=agent_id, agent_update=agent_update, actor=actor) + # If there are provided environment variables, add them in + if agent_update.tool_exec_environment_variables: + agent_state = self._set_environment_variables( + agent_id=agent_state.id, + env_vars=agent_update.tool_exec_environment_variables, + actor=actor, + ) + # Rebuild the system prompt if it's different if agent_update.system and agent_update.system != agent_state.system: agent_state = self.rebuild_system_prompt(agent_id=agent_state.id, actor=actor, force=True, update_timestamp=False) @@ -296,6 +313,43 @@ class AgentManager: agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) agent.hard_delete(session) + # ====================================================================================================================== + # Per Agent Environment Variable Management + # ====================================================================================================================== + @enforce_types + def _set_environment_variables( + self, + agent_id: str, + env_vars: Dict[str, str], + actor: PydanticUser, + ) -> PydanticAgentState: + """ + Adds or replaces the environment variables for the specified agent. + + Args: + agent_id: The agent id. + env_vars: A dictionary of environment variable key-value pairs. + actor: The user performing the action. + + Returns: + PydanticAgentState: The updated agent as a Pydantic model. + """ + with self.session_maker() as session: + # Retrieve the agent + agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) + + # Replace the environment variables + agent.tool_exec_environment_variables = [ + AgentEnvironmentVariableModel(key=key, value=value, agent_id=agent_id, organization_id=actor.organization_id) + for key, value in env_vars.items() + ] + + # Update the agent in the database + agent.update(session, actor=actor) + + # Return the updated agent state + return agent.to_pydantic() + # ====================================================================================================================== # In Context Messages Management # ====================================================================================================================== diff --git a/letta/services/sandbox_config_manager.py b/letta/services/sandbox_config_manager.py index 9e47612e..0511d3ec 100644 --- a/letta/services/sandbox_config_manager.py +++ b/letta/services/sandbox_config_manager.py @@ -5,11 +5,11 @@ from letta.log import get_logger from letta.orm.errors import NoResultFound from letta.orm.sandbox_config import SandboxConfig as SandboxConfigModel from letta.orm.sandbox_config import SandboxEnvironmentVariable as SandboxEnvVarModel +from letta.schemas.environment_variables import SandboxEnvironmentVariable as PydanticEnvVar +from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate from letta.schemas.sandbox_config import LocalSandboxConfig from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig -from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate -from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticEnvVar -from letta.schemas.sandbox_config import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate, SandboxType +from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate, SandboxType from letta.schemas.user import User as PydanticUser from letta.utils import enforce_types, printd diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 803fc98c..a1f13820 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -5,6 +5,7 @@ from letta.functions.functions import parse_source_code from letta.functions.schema_generator import generate_schema from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent from letta.schemas.tool import Tool +from letta.schemas.user import User as PydanticUser def cleanup(client: Union[LocalClient, RESTClient], agent_uuid: str): @@ -27,12 +28,19 @@ def create_tool_from_func(func: callable): ) -def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, UpdateAgent]): +def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, UpdateAgent], actor: PydanticUser): # Assert scalar fields assert agent.system == request.system, f"System prompt mismatch: {agent.system} != {request.system}" assert agent.description == request.description, f"Description mismatch: {agent.description} != {request.description}" assert agent.metadata_ == request.metadata_, f"Metadata mismatch: {agent.metadata_} != {request.metadata_}" + # Assert agent env vars + if hasattr(request, "tool_exec_environment_variables"): + for agent_env_var in agent.tool_exec_environment_variables: + assert agent_env_var.key in request.tool_exec_environment_variables + assert request.tool_exec_environment_variables[agent_env_var.key] == agent_env_var.value + assert agent_env_var.organization_id == actor.organization_id + # Assert agent type if hasattr(request, "agent_type"): assert agent.agent_type == request.agent_type, f"Agent type mismatch: {agent.agent_type} != {request.agent_type}" diff --git a/tests/integration_test_tool_execution_sandbox.py b/tests/integration_test_tool_execution_sandbox.py index 3f64b287..e7824673 100644 --- a/tests/integration_test_tool_execution_sandbox.py +++ b/tests/integration_test_tool_execution_sandbox.py @@ -9,20 +9,14 @@ from sqlalchemy import delete from letta import create_client from letta.functions.function_sets.base import core_memory_append, core_memory_replace -from letta.orm import SandboxConfig, SandboxEnvironmentVariable +from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable from letta.schemas.agent import AgentState from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ChatMemory from letta.schemas.organization import Organization -from letta.schemas.sandbox_config import ( - E2BSandboxConfig, - LocalSandboxConfig, - SandboxConfigCreate, - SandboxConfigUpdate, - SandboxEnvironmentVariableCreate, - SandboxType, -) +from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate, SandboxType from letta.schemas.tool import Tool, ToolCreate from letta.schemas.user import User from letta.services.organization_manager import OrganizationManager diff --git a/tests/test_managers.py b/tests/test_managers.py index 388d477c..2b0ff751 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -35,6 +35,7 @@ from letta.schemas.block import Block as PydanticBlock from letta.schemas.block import BlockUpdate, CreateBlock from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import JobStatus, MessageRole +from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate from letta.schemas.file import FileMetadata as PydanticFileMetadata from letta.schemas.job import Job as PydanticJob from letta.schemas.job import JobUpdate @@ -43,15 +44,7 @@ from letta.schemas.message import Message as PydanticMessage from letta.schemas.message import MessageCreate, MessageUpdate from letta.schemas.organization import Organization as PydanticOrganization from letta.schemas.passage import Passage as PydanticPassage -from letta.schemas.sandbox_config import ( - E2BSandboxConfig, - LocalSandboxConfig, - SandboxConfigCreate, - SandboxConfigUpdate, - SandboxEnvironmentVariableCreate, - SandboxEnvironmentVariableUpdate, - SandboxType, -) +from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate, SandboxType from letta.schemas.source import Source as PydanticSource from letta.schemas.source import SourceUpdate from letta.schemas.tool import Tool as PydanticTool @@ -413,6 +406,7 @@ def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_too metadata_={"test_key": "test_value"}, tool_rules=[InitToolRule(tool_name=print_tool.name)], initial_message_sequence=[MessageCreate(role=MessageRole.user, text="hello world")], + tool_exec_environment_variables={"test_env_var_key_a": "test_env_var_value_a", "test_env_var_key_b": "test_env_var_value_b"}, ) created_agent = server.agent_manager.create_agent( create_agent_request, @@ -482,20 +476,20 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent): def test_create_get_list_agent(server: SyncServer, comprehensive_test_agent_fixture, default_user): # Test agent creation created_agent, create_agent_request = comprehensive_test_agent_fixture - comprehensive_agent_checks(created_agent, create_agent_request) + comprehensive_agent_checks(created_agent, create_agent_request, actor=default_user) # Test get agent get_agent = server.agent_manager.get_agent_by_id(agent_id=created_agent.id, actor=default_user) - comprehensive_agent_checks(get_agent, create_agent_request) + comprehensive_agent_checks(get_agent, create_agent_request, actor=default_user) # Test get agent name get_agent_name = server.agent_manager.get_agent_by_name(agent_name=created_agent.name, actor=default_user) - comprehensive_agent_checks(get_agent_name, create_agent_request) + comprehensive_agent_checks(get_agent_name, create_agent_request, actor=default_user) # Test list agent list_agents = server.agent_manager.list_agents(actor=default_user) assert len(list_agents) == 1 - comprehensive_agent_checks(list_agents[0], create_agent_request) + comprehensive_agent_checks(list_agents[0], create_agent_request, actor=default_user) # Test deleting the agent server.agent_manager.delete_agent(get_agent.id, default_user) @@ -566,10 +560,11 @@ def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, othe embedding_config=EmbeddingConfig.default_config(model_name="letta"), message_ids=["10", "20"], metadata_={"train_key": "train_value"}, + tool_exec_environment_variables={"new_tool_exec_key": "new_tool_exec_value"}, ) updated_agent = server.agent_manager.update_agent(agent.id, update_agent_request, actor=default_user) - comprehensive_agent_checks(updated_agent, update_agent_request) + comprehensive_agent_checks(updated_agent, update_agent_request, actor=default_user) assert updated_agent.message_ids == update_agent_request.message_ids