feat: Allow per-agent tool execution environment variables (#509)

This commit is contained in:
Matthew Zhou
2025-01-05 19:06:38 -10:00
committed by GitHub
parent 130fd63c6b
commit 0ef692441f
16 changed files with 249 additions and 73 deletions

View File

@@ -0,0 +1,51 @@
"""Add per agent environment variables
Revision ID: 400501b04bf0
Revises: e78b4e82db30
Create Date: 2025-01-04 20:45:28.024690
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "400501b04bf0"
down_revision: Union[str, None] = "e78b4e82db30"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"agent_environment_variables",
sa.Column("id", sa.String(), nullable=False),
sa.Column("key", sa.String(), nullable=False),
sa.Column("value", sa.String(), nullable=False),
sa.Column("description", sa.String(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
sa.Column("_created_by_id", sa.String(), nullable=True),
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
sa.Column("organization_id", sa.String(), nullable=False),
sa.Column("agent_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(
["organization_id"],
["organizations.id"],
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("key", "agent_id", name="uix_key_agent"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("agent_environment_variables")
# ### end Alembic commands ###

View File

@@ -4,9 +4,10 @@ import uuid
from letta import create_client
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory
from letta.schemas.sandbox_config import SandboxEnvironmentVariableCreate, SandboxType
from letta.schemas.sandbox_config import SandboxType
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.settings import tool_settings

View File

@@ -15,6 +15,11 @@ from letta.schemas.embedding_config import EmbeddingConfig
# new schemas
from letta.schemas.enums import JobStatus, MessageRole
from letta.schemas.environment_variables import (
SandboxEnvironmentVariable,
SandboxEnvironmentVariableCreate,
SandboxEnvironmentVariableUpdate,
)
from letta.schemas.file import FileMetadata
from letta.schemas.job import Job
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
@@ -25,16 +30,7 @@ from letta.schemas.message import Message, MessageCreate, MessageUpdate
from letta.schemas.openai.chat_completions import ToolCall
from letta.schemas.organization import Organization
from letta.schemas.passage import Passage
from letta.schemas.sandbox_config import (
E2BSandboxConfig,
LocalSandboxConfig,
SandboxConfig,
SandboxConfigCreate,
SandboxConfigUpdate,
SandboxEnvironmentVariable,
SandboxEnvironmentVariableCreate,
SandboxEnvironmentVariableUpdate,
)
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfig, SandboxConfigCreate, SandboxConfigUpdate
from letta.schemas.source import Source, SourceCreate, SourceUpdate
from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
from letta.schemas.tool_rule import BaseToolRule

View File

@@ -31,7 +31,7 @@ class Agent(SqlalchemyBase, OrganizationMixin):
# agent generates its own id
# TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase
# TODO: Move this in this PR? at the very end?
# TODO: Some still rely on the Pydantic object to do this
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"agent-{uuid.uuid4()}")
# Descriptor fields
@@ -61,6 +61,13 @@ class Agent(SqlalchemyBase, OrganizationMixin):
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="agents")
tool_exec_environment_variables: Mapped[List["AgentEnvironmentVariable"]] = relationship(
"AgentEnvironmentVariable",
back_populates="agent",
cascade="all, delete-orphan",
lazy="selectin",
doc="Environment variables associated with this agent.",
)
tools: Mapped[List["Tool"]] = relationship("Tool", secondary="tools_agents", lazy="selectin", passive_deletes=True)
sources: Mapped[List["Source"]] = relationship("Source", secondary="sources_agents", lazy="selectin")
core_memory: Mapped[List["Block"]] = relationship("Block", secondary="blocks_agents", lazy="selectin")
@@ -119,5 +126,6 @@ class Agent(SqlalchemyBase, OrganizationMixin):
"last_updated_by_id": self.last_updated_by_id,
"created_at": self.created_at,
"updated_at": self.updated_at,
"tool_exec_environment_variables": self.tool_exec_environment_variables,
}
return self.__pydantic_model__(**state)

View File

@@ -9,6 +9,7 @@ if TYPE_CHECKING:
from letta.orm.agent import Agent
from letta.orm.file import FileMetadata
from letta.orm.sandbox_config import AgentEnvironmentVariable
from letta.orm.tool import Tool
from letta.orm.user import User
@@ -33,6 +34,9 @@ class Organization(SqlalchemyBase):
sandbox_environment_variables: Mapped[List["SandboxEnvironmentVariable"]] = relationship(
"SandboxEnvironmentVariable", back_populates="organization", cascade="all, delete-orphan"
)
agent_environment_variables: Mapped[List["AgentEnvironmentVariable"]] = relationship(
"AgentEnvironmentVariable", back_populates="organization", cascade="all, delete-orphan"
)
# relationships
agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")

View File

@@ -1,3 +1,4 @@
import uuid
from typing import TYPE_CHECKING, Dict, List, Optional
from sqlalchemy import JSON
@@ -5,13 +6,14 @@ from sqlalchemy import Enum as SqlEnum
from sqlalchemy import String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin, SandboxConfigMixin
from letta.orm.mixins import AgentMixin, OrganizationMixin, SandboxConfigMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.environment_variables import SandboxEnvironmentVariable as PydanticSandboxEnvironmentVariable
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticSandboxEnvironmentVariable
from letta.schemas.sandbox_config import SandboxType
if TYPE_CHECKING:
from letta.orm.agent import Agent
from letta.orm.organization import Organization
@@ -52,3 +54,21 @@ class SandboxEnvironmentVariable(SqlalchemyBase, OrganizationMixin, SandboxConfi
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="sandbox_environment_variables")
sandbox_config: Mapped["SandboxConfig"] = relationship("SandboxConfig", back_populates="sandbox_environment_variables")
class AgentEnvironmentVariable(SqlalchemyBase, OrganizationMixin, AgentMixin):
"""ORM model for environment variables associated with agents."""
__tablename__ = "agent_environment_variables"
# We cannot have duplicate key names for the same agent, the env var would get overwritten
__table_args__ = (UniqueConstraint("key", "agent_id", name="uix_key_agent"),)
# agent_env_var generates its own id
# TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"agent-env-{uuid.uuid4()}")
key: Mapped[str] = mapped_column(String, nullable=False, doc="The name of the environment variable.")
value: Mapped[str] = mapped_column(String, nullable=False, doc="The value of the environment variable.")
description: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="An optional description of the environment variable.")
organization: Mapped["Organization"] = relationship("Organization", back_populates="agent_environment_variables")
agent: Mapped[List["Agent"]] = relationship("Agent", back_populates="tool_exec_environment_variables")

View File

@@ -6,6 +6,7 @@ from pydantic import BaseModel, Field, field_validator
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
from letta.schemas.block import CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.environment_variables import AgentEnvironmentVariable
from letta.schemas.letta_base import OrmMetadataBase
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
@@ -78,6 +79,9 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
tools: List[Tool] = Field(..., description="The tools used by the agent.")
sources: List[Source] = Field(..., description="The sources used by the agent.")
tags: List[str] = Field(..., description="The tags associated with the agent.")
tool_exec_environment_variables: List[AgentEnvironmentVariable] = Field(
..., description="The environment variables for tool execution specific to this agent."
)
class CreateAgent(BaseModel, validate_assignment=True): #
@@ -120,6 +124,9 @@ class CreateAgent(BaseModel, validate_assignment=True): #
embedding_chunk_size: Optional[int] = Field(DEFAULT_EMBEDDING_CHUNK_SIZE, description="The embedding chunk size used by the agent.")
from_template: Optional[str] = Field(None, description="The template id used to configure the agent")
project_id: Optional[str] = Field(None, description="The project id that the agent will be associated with.")
tool_exec_environment_variables: Optional[Dict[str, str]] = Field(
None, description="The environment variables for tool execution specific to this agent."
)
@field_validator("name")
@classmethod
@@ -184,6 +191,9 @@ class UpdateAgent(BaseModel):
message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
description: Optional[str] = Field(None, description="The description of the agent.")
metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_")
tool_exec_environment_variables: Optional[Dict[str, str]] = Field(
None, description="The environment variables for tool execution specific to this agent."
)
class Config:
extra = "ignore" # Ignores extra fields

View File

@@ -0,0 +1,62 @@
from typing import Optional
from pydantic import Field
from letta.schemas.letta_base import LettaBase, OrmMetadataBase
# Base Environment Variable
class EnvironmentVariableBase(OrmMetadataBase):
id: str = Field(..., description="The unique identifier for the environment variable.")
key: str = Field(..., description="The name of the environment variable.")
value: str = Field(..., description="The value of the environment variable.")
description: Optional[str] = Field(None, description="An optional description of the environment variable.")
organization_id: Optional[str] = Field(None, description="The ID of the organization this environment variable belongs to.")
class EnvironmentVariableCreateBase(LettaBase):
key: str = Field(..., description="The name of the environment variable.")
value: str = Field(..., description="The value of the environment variable.")
description: Optional[str] = Field(None, description="An optional description of the environment variable.")
class EnvironmentVariableUpdateBase(LettaBase):
key: Optional[str] = Field(None, description="The name of the environment variable.")
value: Optional[str] = Field(None, description="The value of the environment variable.")
description: Optional[str] = Field(None, description="An optional description of the environment variable.")
# Sandbox-Specific Environment Variable
class SandboxEnvironmentVariableBase(EnvironmentVariableBase):
__id_prefix__ = "sandbox-env"
sandbox_config_id: str = Field(..., description="The ID of the sandbox config this environment variable belongs to.")
class SandboxEnvironmentVariable(SandboxEnvironmentVariableBase):
id: str = SandboxEnvironmentVariableBase.generate_id_field()
class SandboxEnvironmentVariableCreate(EnvironmentVariableCreateBase):
pass
class SandboxEnvironmentVariableUpdate(EnvironmentVariableUpdateBase):
pass
# Agent-Specific Environment Variable
class AgentEnvironmentVariableBase(EnvironmentVariableBase):
__id_prefix__ = "agent-env"
agent_id: str = Field(..., description="The ID of the agent this environment variable belongs to.")
class AgentEnvironmentVariable(AgentEnvironmentVariableBase):
id: str = AgentEnvironmentVariableBase.generate_id_field()
class AgentEnvironmentVariableCreate(EnvironmentVariableCreateBase):
pass
class AgentEnvironmentVariableUpdate(EnvironmentVariableUpdateBase):
pass

View File

@@ -102,31 +102,3 @@ class SandboxConfigUpdate(LettaBase):
"""Pydantic model for updating SandboxConfig fields."""
config: Union[LocalSandboxConfig, E2BSandboxConfig] = Field(None, description="The JSON configuration data for the sandbox.")
# Environment Variable
class SandboxEnvironmentVariableBase(OrmMetadataBase):
__id_prefix__ = "sandbox-env"
class SandboxEnvironmentVariable(SandboxEnvironmentVariableBase):
id: str = SandboxEnvironmentVariableBase.generate_id_field()
key: str = Field(..., description="The name of the environment variable.")
value: str = Field(..., description="The value of the environment variable.")
description: Optional[str] = Field(None, description="An optional description of the environment variable.")
sandbox_config_id: str = Field(..., description="The ID of the sandbox config this environment variable belongs to.")
organization_id: Optional[str] = Field(None, description="The ID of the organization this environment variable belongs to.")
class SandboxEnvironmentVariableCreate(LettaBase):
key: str = Field(..., description="The name of the environment variable.")
value: str = Field(..., description="The value of the environment variable.")
description: Optional[str] = Field(None, description="An optional description of the environment variable.")
class SandboxEnvironmentVariableUpdate(LettaBase):
"""Pydantic model for updating SandboxEnvironmentVariable fields."""
key: Optional[str] = Field(None, description="The name of the environment variable.")
value: Optional[str] = Field(None, description="The value of the environment variable.")
description: Optional[str] = Field(None, description="An optional description of the environment variable.")

View File

@@ -2,10 +2,10 @@ from typing import List, Optional
from fastapi import APIRouter, Depends, Query
from letta.schemas.environment_variables import SandboxEnvironmentVariable as PydanticEnvVar
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate
from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticEnvVar
from letta.schemas.sandbox_config import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate, SandboxType
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate, SandboxType
from letta.server.rest_api.utils import get_letta_server, get_user_id
from letta.server.server import SyncServer

View File

@@ -46,6 +46,7 @@ from letta.schemas.embedding_config import EmbeddingConfig
# openai schemas
from letta.schemas.enums import JobStatus
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
from letta.schemas.job import Job, JobUpdate
from letta.schemas.letta_message import LettaMessage, ToolReturnMessage
from letta.schemas.llm_config import LLMConfig
@@ -53,7 +54,7 @@ from letta.schemas.memory import ArchivalMemorySummary, ContextWindowOverview, M
from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate
from letta.schemas.organization import Organization
from letta.schemas.passage import Passage
from letta.schemas.sandbox_config import SandboxEnvironmentVariableCreate, SandboxType
from letta.schemas.sandbox_config import SandboxType
from letta.schemas.source import Source
from letta.schemas.tool import Tool
from letta.schemas.usage import LettaUsageStatistics

View File

@@ -14,6 +14,7 @@ from letta.orm import Source as SourceModel
from letta.orm import SourcePassage, SourcesAgents
from letta.orm import Tool as ToolModel
from letta.orm.errors import NoResultFound
from letta.orm.sandbox_config import AgentEnvironmentVariable as AgentEnvironmentVariableModel
from letta.orm.sqlite_functions import adapt_array
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent
@@ -116,6 +117,14 @@ class AgentManager:
actor=actor,
)
# If there are provided environment variables, add them in
if agent_create.tool_exec_environment_variables:
agent_state = self._set_environment_variables(
agent_id=agent_state.id,
env_vars=agent_create.tool_exec_environment_variables,
actor=actor,
)
# TODO: See if we can merge this into the above SQL create call for performance reasons
# Generate a sequence of initial messages to put in the buffer
init_messages = initialize_message_sequence(
@@ -192,6 +201,14 @@ class AgentManager:
def update_agent(self, agent_id: str, agent_update: UpdateAgent, actor: PydanticUser) -> PydanticAgentState:
agent_state = self._update_agent(agent_id=agent_id, agent_update=agent_update, actor=actor)
# If there are provided environment variables, add them in
if agent_update.tool_exec_environment_variables:
agent_state = self._set_environment_variables(
agent_id=agent_state.id,
env_vars=agent_update.tool_exec_environment_variables,
actor=actor,
)
# Rebuild the system prompt if it's different
if agent_update.system and agent_update.system != agent_state.system:
agent_state = self.rebuild_system_prompt(agent_id=agent_state.id, actor=actor, force=True, update_timestamp=False)
@@ -296,6 +313,43 @@ class AgentManager:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
agent.hard_delete(session)
# ======================================================================================================================
# Per Agent Environment Variable Management
# ======================================================================================================================
@enforce_types
def _set_environment_variables(
self,
agent_id: str,
env_vars: Dict[str, str],
actor: PydanticUser,
) -> PydanticAgentState:
"""
Adds or replaces the environment variables for the specified agent.
Args:
agent_id: The agent id.
env_vars: A dictionary of environment variable key-value pairs.
actor: The user performing the action.
Returns:
PydanticAgentState: The updated agent as a Pydantic model.
"""
with self.session_maker() as session:
# Retrieve the agent
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
# Replace the environment variables
agent.tool_exec_environment_variables = [
AgentEnvironmentVariableModel(key=key, value=value, agent_id=agent_id, organization_id=actor.organization_id)
for key, value in env_vars.items()
]
# Update the agent in the database
agent.update(session, actor=actor)
# Return the updated agent state
return agent.to_pydantic()
# ======================================================================================================================
# In Context Messages Management
# ======================================================================================================================

View File

@@ -5,11 +5,11 @@ from letta.log import get_logger
from letta.orm.errors import NoResultFound
from letta.orm.sandbox_config import SandboxConfig as SandboxConfigModel
from letta.orm.sandbox_config import SandboxEnvironmentVariable as SandboxEnvVarModel
from letta.schemas.environment_variables import SandboxEnvironmentVariable as PydanticEnvVar
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
from letta.schemas.sandbox_config import LocalSandboxConfig
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate
from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticEnvVar
from letta.schemas.sandbox_config import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate, SandboxType
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate, SandboxType
from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types, printd

View File

@@ -5,6 +5,7 @@ from letta.functions.functions import parse_source_code
from letta.functions.schema_generator import generate_schema
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
from letta.schemas.tool import Tool
from letta.schemas.user import User as PydanticUser
def cleanup(client: Union[LocalClient, RESTClient], agent_uuid: str):
@@ -27,12 +28,19 @@ def create_tool_from_func(func: callable):
)
def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, UpdateAgent]):
def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, UpdateAgent], actor: PydanticUser):
# Assert scalar fields
assert agent.system == request.system, f"System prompt mismatch: {agent.system} != {request.system}"
assert agent.description == request.description, f"Description mismatch: {agent.description} != {request.description}"
assert agent.metadata_ == request.metadata_, f"Metadata mismatch: {agent.metadata_} != {request.metadata_}"
# Assert agent env vars
if hasattr(request, "tool_exec_environment_variables"):
for agent_env_var in agent.tool_exec_environment_variables:
assert agent_env_var.key in request.tool_exec_environment_variables
assert request.tool_exec_environment_variables[agent_env_var.key] == agent_env_var.value
assert agent_env_var.organization_id == actor.organization_id
# Assert agent type
if hasattr(request, "agent_type"):
assert agent.agent_type == request.agent_type, f"Agent type mismatch: {agent.agent_type} != {request.agent_type}"

View File

@@ -9,20 +9,14 @@ from sqlalchemy import delete
from letta import create_client
from letta.functions.function_sets.base import core_memory_append, core_memory_replace
from letta.orm import SandboxConfig, SandboxEnvironmentVariable
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory
from letta.schemas.organization import Organization
from letta.schemas.sandbox_config import (
E2BSandboxConfig,
LocalSandboxConfig,
SandboxConfigCreate,
SandboxConfigUpdate,
SandboxEnvironmentVariableCreate,
SandboxType,
)
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate, SandboxType
from letta.schemas.tool import Tool, ToolCreate
from letta.schemas.user import User
from letta.services.organization_manager import OrganizationManager

View File

@@ -35,6 +35,7 @@ from letta.schemas.block import Block as PydanticBlock
from letta.schemas.block import BlockUpdate, CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import JobStatus, MessageRole
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
from letta.schemas.file import FileMetadata as PydanticFileMetadata
from letta.schemas.job import Job as PydanticJob
from letta.schemas.job import JobUpdate
@@ -43,15 +44,7 @@ from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import MessageCreate, MessageUpdate
from letta.schemas.organization import Organization as PydanticOrganization
from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.sandbox_config import (
E2BSandboxConfig,
LocalSandboxConfig,
SandboxConfigCreate,
SandboxConfigUpdate,
SandboxEnvironmentVariableCreate,
SandboxEnvironmentVariableUpdate,
SandboxType,
)
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate, SandboxType
from letta.schemas.source import Source as PydanticSource
from letta.schemas.source import SourceUpdate
from letta.schemas.tool import Tool as PydanticTool
@@ -413,6 +406,7 @@ def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_too
metadata_={"test_key": "test_value"},
tool_rules=[InitToolRule(tool_name=print_tool.name)],
initial_message_sequence=[MessageCreate(role=MessageRole.user, text="hello world")],
tool_exec_environment_variables={"test_env_var_key_a": "test_env_var_value_a", "test_env_var_key_b": "test_env_var_value_b"},
)
created_agent = server.agent_manager.create_agent(
create_agent_request,
@@ -482,20 +476,20 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent):
def test_create_get_list_agent(server: SyncServer, comprehensive_test_agent_fixture, default_user):
# Test agent creation
created_agent, create_agent_request = comprehensive_test_agent_fixture
comprehensive_agent_checks(created_agent, create_agent_request)
comprehensive_agent_checks(created_agent, create_agent_request, actor=default_user)
# Test get agent
get_agent = server.agent_manager.get_agent_by_id(agent_id=created_agent.id, actor=default_user)
comprehensive_agent_checks(get_agent, create_agent_request)
comprehensive_agent_checks(get_agent, create_agent_request, actor=default_user)
# Test get agent name
get_agent_name = server.agent_manager.get_agent_by_name(agent_name=created_agent.name, actor=default_user)
comprehensive_agent_checks(get_agent_name, create_agent_request)
comprehensive_agent_checks(get_agent_name, create_agent_request, actor=default_user)
# Test list agent
list_agents = server.agent_manager.list_agents(actor=default_user)
assert len(list_agents) == 1
comprehensive_agent_checks(list_agents[0], create_agent_request)
comprehensive_agent_checks(list_agents[0], create_agent_request, actor=default_user)
# Test deleting the agent
server.agent_manager.delete_agent(get_agent.id, default_user)
@@ -566,10 +560,11 @@ def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, othe
embedding_config=EmbeddingConfig.default_config(model_name="letta"),
message_ids=["10", "20"],
metadata_={"train_key": "train_value"},
tool_exec_environment_variables={"new_tool_exec_key": "new_tool_exec_value"},
)
updated_agent = server.agent_manager.update_agent(agent.id, update_agent_request, actor=default_user)
comprehensive_agent_checks(updated_agent, update_agent_request)
comprehensive_agent_checks(updated_agent, update_agent_request, actor=default_user)
assert updated_agent.message_ids == update_agent_request.message_ids