chore: Improved sandboxing support (#2333)
Co-authored-by: Shubham Naik <shub@memgpt.ai> Co-authored-by: Shubham Naik <shubham.naik10@gmail.com> Co-authored-by: Caren Thomas <caren@letta.com> Co-authored-by: Sarah Wooders <sarahwooders@gmail.com> Co-authored-by: cpacker <packercharles@gmail.com>
This commit is contained in:
@@ -0,0 +1,51 @@
|
||||
"""Add per agent environment variables
|
||||
|
||||
Revision ID: 400501b04bf0
|
||||
Revises: e78b4e82db30
|
||||
Create Date: 2025-01-04 20:45:28.024690
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "400501b04bf0"
|
||||
down_revision: Union[str, None] = "e78b4e82db30"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"agent_environment_variables",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("key", sa.String(), nullable=False),
|
||||
sa.Column("value", sa.String(), nullable=False),
|
||||
sa.Column("description", sa.String(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
|
||||
sa.Column("_created_by_id", sa.String(), nullable=True),
|
||||
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
|
||||
sa.Column("organization_id", sa.String(), nullable=False),
|
||||
sa.Column("agent_id", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(["agent_id"], ["agents.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"],
|
||||
["organizations.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("key", "agent_id", name="uix_key_agent"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("agent_environment_variables")
|
||||
# ### end Alembic commands ###
|
||||
@@ -4,9 +4,10 @@ import uuid
|
||||
|
||||
from letta import create_client
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ChatMemory
|
||||
from letta.schemas.sandbox_config import SandboxEnvironmentVariableCreate, SandboxType
|
||||
from letta.schemas.sandbox_config import SandboxType
|
||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
from letta.settings import tool_settings
|
||||
|
||||
|
||||
@@ -15,6 +15,11 @@ from letta.schemas.embedding_config import EmbeddingConfig
|
||||
|
||||
# new schemas
|
||||
from letta.schemas.enums import JobStatus, MessageRole
|
||||
from letta.schemas.environment_variables import (
|
||||
SandboxEnvironmentVariable,
|
||||
SandboxEnvironmentVariableCreate,
|
||||
SandboxEnvironmentVariableUpdate,
|
||||
)
|
||||
from letta.schemas.file import FileMetadata
|
||||
from letta.schemas.job import Job
|
||||
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
|
||||
@@ -25,16 +30,7 @@ from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
||||
from letta.schemas.openai.chat_completions import ToolCall
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.sandbox_config import (
|
||||
E2BSandboxConfig,
|
||||
LocalSandboxConfig,
|
||||
SandboxConfig,
|
||||
SandboxConfigCreate,
|
||||
SandboxConfigUpdate,
|
||||
SandboxEnvironmentVariable,
|
||||
SandboxEnvironmentVariableCreate,
|
||||
SandboxEnvironmentVariableUpdate,
|
||||
)
|
||||
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfig, SandboxConfigCreate, SandboxConfigUpdate
|
||||
from letta.schemas.source import Source, SourceCreate, SourceUpdate
|
||||
from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
|
||||
from letta.schemas.tool_rule import BaseToolRule
|
||||
|
||||
@@ -45,7 +45,6 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
|
||||
# break
|
||||
if sse.data in [status.value for status in MessageStreamStatus]:
|
||||
# break
|
||||
# print("sse.data::", sse.data)
|
||||
yield MessageStreamStatus(sse.data)
|
||||
else:
|
||||
chunk_data = json.loads(sse.data)
|
||||
|
||||
@@ -31,7 +31,7 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
||||
|
||||
# agent generates its own id
|
||||
# TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase
|
||||
# TODO: Move this in this PR? at the very end?
|
||||
# TODO: Some still rely on the Pydantic object to do this
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"agent-{uuid.uuid4()}")
|
||||
|
||||
# Descriptor fields
|
||||
@@ -61,6 +61,13 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="agents")
|
||||
tool_exec_environment_variables: Mapped[List["AgentEnvironmentVariable"]] = relationship(
|
||||
"AgentEnvironmentVariable",
|
||||
back_populates="agent",
|
||||
cascade="all, delete-orphan",
|
||||
lazy="selectin",
|
||||
doc="Environment variables associated with this agent.",
|
||||
)
|
||||
tools: Mapped[List["Tool"]] = relationship("Tool", secondary="tools_agents", lazy="selectin", passive_deletes=True)
|
||||
sources: Mapped[List["Source"]] = relationship("Source", secondary="sources_agents", lazy="selectin")
|
||||
core_memory: Mapped[List["Block"]] = relationship("Block", secondary="blocks_agents", lazy="selectin")
|
||||
@@ -119,5 +126,6 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
||||
"last_updated_by_id": self.last_updated_by_id,
|
||||
"created_at": self.created_at,
|
||||
"updated_at": self.updated_at,
|
||||
"tool_exec_environment_variables": self.tool_exec_environment_variables,
|
||||
}
|
||||
return self.__pydantic_model__(**state)
|
||||
|
||||
@@ -9,6 +9,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from letta.orm.agent import Agent
|
||||
from letta.orm.file import FileMetadata
|
||||
from letta.orm.sandbox_config import AgentEnvironmentVariable
|
||||
from letta.orm.tool import Tool
|
||||
from letta.orm.user import User
|
||||
|
||||
@@ -33,6 +34,9 @@ class Organization(SqlalchemyBase):
|
||||
sandbox_environment_variables: Mapped[List["SandboxEnvironmentVariable"]] = relationship(
|
||||
"SandboxEnvironmentVariable", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
agent_environment_variables: Mapped[List["AgentEnvironmentVariable"]] = relationship(
|
||||
"AgentEnvironmentVariable", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# relationships
|
||||
agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import JSON
|
||||
@@ -5,13 +6,14 @@ from sqlalchemy import Enum as SqlEnum
|
||||
from sqlalchemy import String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin, SandboxConfigMixin
|
||||
from letta.orm.mixins import AgentMixin, OrganizationMixin, SandboxConfigMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariable as PydanticSandboxEnvironmentVariable
|
||||
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
|
||||
from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticSandboxEnvironmentVariable
|
||||
from letta.schemas.sandbox_config import SandboxType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.agent import Agent
|
||||
from letta.orm.organization import Organization
|
||||
|
||||
|
||||
@@ -52,3 +54,21 @@ class SandboxEnvironmentVariable(SqlalchemyBase, OrganizationMixin, SandboxConfi
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="sandbox_environment_variables")
|
||||
sandbox_config: Mapped["SandboxConfig"] = relationship("SandboxConfig", back_populates="sandbox_environment_variables")
|
||||
|
||||
|
||||
class AgentEnvironmentVariable(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
"""ORM model for environment variables associated with agents."""
|
||||
|
||||
__tablename__ = "agent_environment_variables"
|
||||
# We cannot have duplicate key names for the same agent, the env var would get overwritten
|
||||
__table_args__ = (UniqueConstraint("key", "agent_id", name="uix_key_agent"),)
|
||||
|
||||
# agent_env_var generates its own id
|
||||
# TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"agent-env-{uuid.uuid4()}")
|
||||
key: Mapped[str] = mapped_column(String, nullable=False, doc="The name of the environment variable.")
|
||||
value: Mapped[str] = mapped_column(String, nullable=False, doc="The value of the environment variable.")
|
||||
description: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="An optional description of the environment variable.")
|
||||
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="agent_environment_variables")
|
||||
agent: Mapped[List["Agent"]] = relationship("Agent", back_populates="tool_exec_environment_variables")
|
||||
|
||||
@@ -6,6 +6,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||
from letta.constants import DEFAULT_EMBEDDING_CHUNK_SIZE
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.environment_variables import AgentEnvironmentVariable
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import Memory
|
||||
@@ -78,6 +79,9 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
|
||||
tools: List[Tool] = Field(..., description="The tools used by the agent.")
|
||||
sources: List[Source] = Field(..., description="The sources used by the agent.")
|
||||
tags: List[str] = Field(..., description="The tags associated with the agent.")
|
||||
tool_exec_environment_variables: List[AgentEnvironmentVariable] = Field(
|
||||
..., description="The environment variables for tool execution specific to this agent."
|
||||
)
|
||||
|
||||
|
||||
class CreateAgent(BaseModel, validate_assignment=True): #
|
||||
@@ -120,6 +124,9 @@ class CreateAgent(BaseModel, validate_assignment=True): #
|
||||
embedding_chunk_size: Optional[int] = Field(DEFAULT_EMBEDDING_CHUNK_SIZE, description="The embedding chunk size used by the agent.")
|
||||
from_template: Optional[str] = Field(None, description="The template id used to configure the agent")
|
||||
project_id: Optional[str] = Field(None, description="The project id that the agent will be associated with.")
|
||||
tool_exec_environment_variables: Optional[Dict[str, str]] = Field(
|
||||
None, description="The environment variables for tool execution specific to this agent."
|
||||
)
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
@@ -184,6 +191,9 @@ class UpdateAgent(BaseModel):
|
||||
message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
|
||||
description: Optional[str] = Field(None, description="The description of the agent.")
|
||||
metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_")
|
||||
tool_exec_environment_variables: Optional[Dict[str, str]] = Field(
|
||||
None, description="The environment variables for tool execution specific to this agent."
|
||||
)
|
||||
|
||||
class Config:
|
||||
extra = "ignore" # Ignores extra fields
|
||||
|
||||
@@ -30,8 +30,8 @@ class JobStatus(str, Enum):
|
||||
|
||||
|
||||
class MessageStreamStatus(str, Enum):
|
||||
done_generation = "[DONE_GEN]"
|
||||
done_step = "[DONE_STEP]"
|
||||
# done_generation = "[DONE_GEN]"
|
||||
# done_step = "[DONE_STEP]"
|
||||
done = "[DONE]"
|
||||
|
||||
|
||||
|
||||
62
letta/schemas/environment_variables.py
Normal file
62
letta/schemas/environment_variables.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.letta_base import LettaBase, OrmMetadataBase
|
||||
|
||||
|
||||
# Base Environment Variable
|
||||
class EnvironmentVariableBase(OrmMetadataBase):
|
||||
id: str = Field(..., description="The unique identifier for the environment variable.")
|
||||
key: str = Field(..., description="The name of the environment variable.")
|
||||
value: str = Field(..., description="The value of the environment variable.")
|
||||
description: Optional[str] = Field(None, description="An optional description of the environment variable.")
|
||||
organization_id: Optional[str] = Field(None, description="The ID of the organization this environment variable belongs to.")
|
||||
|
||||
|
||||
class EnvironmentVariableCreateBase(LettaBase):
|
||||
key: str = Field(..., description="The name of the environment variable.")
|
||||
value: str = Field(..., description="The value of the environment variable.")
|
||||
description: Optional[str] = Field(None, description="An optional description of the environment variable.")
|
||||
|
||||
|
||||
class EnvironmentVariableUpdateBase(LettaBase):
|
||||
key: Optional[str] = Field(None, description="The name of the environment variable.")
|
||||
value: Optional[str] = Field(None, description="The value of the environment variable.")
|
||||
description: Optional[str] = Field(None, description="An optional description of the environment variable.")
|
||||
|
||||
|
||||
# Sandbox-Specific Environment Variable
|
||||
class SandboxEnvironmentVariableBase(EnvironmentVariableBase):
|
||||
__id_prefix__ = "sandbox-env"
|
||||
sandbox_config_id: str = Field(..., description="The ID of the sandbox config this environment variable belongs to.")
|
||||
|
||||
|
||||
class SandboxEnvironmentVariable(SandboxEnvironmentVariableBase):
|
||||
id: str = SandboxEnvironmentVariableBase.generate_id_field()
|
||||
|
||||
|
||||
class SandboxEnvironmentVariableCreate(EnvironmentVariableCreateBase):
|
||||
pass
|
||||
|
||||
|
||||
class SandboxEnvironmentVariableUpdate(EnvironmentVariableUpdateBase):
|
||||
pass
|
||||
|
||||
|
||||
# Agent-Specific Environment Variable
|
||||
class AgentEnvironmentVariableBase(EnvironmentVariableBase):
|
||||
__id_prefix__ = "agent-env"
|
||||
agent_id: str = Field(..., description="The ID of the agent this environment variable belongs to.")
|
||||
|
||||
|
||||
class AgentEnvironmentVariable(AgentEnvironmentVariableBase):
|
||||
id: str = AgentEnvironmentVariableBase.generate_id_field()
|
||||
|
||||
|
||||
class AgentEnvironmentVariableCreate(EnvironmentVariableCreateBase):
|
||||
pass
|
||||
|
||||
|
||||
class AgentEnvironmentVariableUpdate(EnvironmentVariableUpdateBase):
|
||||
pass
|
||||
@@ -29,19 +29,20 @@ class LettaResponse(BaseModel):
|
||||
json_schema_extra={
|
||||
"items": {
|
||||
"oneOf": [
|
||||
{"x-ref-name": "SystemMessage"},
|
||||
{"x-ref-name": "UserMessage"},
|
||||
{"x-ref-name": "ReasoningMessage"},
|
||||
{"x-ref-name": "ToolCallMessage"},
|
||||
{"x-ref-name": "ToolReturnMessage"},
|
||||
{"x-ref-name": "AssistantMessage"},
|
||||
{"$ref": "#/components/schemas/SystemMessage-Output"},
|
||||
{"$ref": "#/components/schemas/UserMessage-Output"},
|
||||
{"$ref": "#/components/schemas/ReasoningMessage"},
|
||||
{"$ref": "#/components/schemas/ToolCallMessage"},
|
||||
{"$ref": "#/components/schemas/ToolReturnMessage"},
|
||||
{"$ref": "#/components/schemas/AssistantMessage-Output"},
|
||||
],
|
||||
"discriminator": {"propertyName": "message_type"},
|
||||
}
|
||||
},
|
||||
)
|
||||
usage: LettaUsageStatistics = Field(
|
||||
..., description="The usage statistics of the agent.", json_schema_extra={"x-ref-name": "LettaUsageStatistics"}
|
||||
...,
|
||||
description="The usage statistics of the agent.",
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
|
||||
@@ -102,31 +102,3 @@ class SandboxConfigUpdate(LettaBase):
|
||||
"""Pydantic model for updating SandboxConfig fields."""
|
||||
|
||||
config: Union[LocalSandboxConfig, E2BSandboxConfig] = Field(None, description="The JSON configuration data for the sandbox.")
|
||||
|
||||
|
||||
# Environment Variable
|
||||
class SandboxEnvironmentVariableBase(OrmMetadataBase):
|
||||
__id_prefix__ = "sandbox-env"
|
||||
|
||||
|
||||
class SandboxEnvironmentVariable(SandboxEnvironmentVariableBase):
|
||||
id: str = SandboxEnvironmentVariableBase.generate_id_field()
|
||||
key: str = Field(..., description="The name of the environment variable.")
|
||||
value: str = Field(..., description="The value of the environment variable.")
|
||||
description: Optional[str] = Field(None, description="An optional description of the environment variable.")
|
||||
sandbox_config_id: str = Field(..., description="The ID of the sandbox config this environment variable belongs to.")
|
||||
organization_id: Optional[str] = Field(None, description="The ID of the organization this environment variable belongs to.")
|
||||
|
||||
|
||||
class SandboxEnvironmentVariableCreate(LettaBase):
|
||||
key: str = Field(..., description="The name of the environment variable.")
|
||||
value: str = Field(..., description="The value of the environment variable.")
|
||||
description: Optional[str] = Field(None, description="An optional description of the environment variable.")
|
||||
|
||||
|
||||
class SandboxEnvironmentVariableUpdate(LettaBase):
|
||||
"""Pydantic model for updating SandboxEnvironmentVariable fields."""
|
||||
|
||||
key: Optional[str] = Field(None, description="The name of the environment variable.")
|
||||
value: Optional[str] = Field(None, description="The value of the environment variable.")
|
||||
description: Optional[str] = Field(None, description="An optional description of the environment variable.")
|
||||
|
||||
@@ -292,8 +292,8 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# if multi_step = True, the stream ends when the agent yields
|
||||
# if multi_step = False, the stream ends when the step ends
|
||||
self.multi_step = multi_step
|
||||
self.multi_step_indicator = MessageStreamStatus.done_step
|
||||
self.multi_step_gen_indicator = MessageStreamStatus.done_generation
|
||||
# self.multi_step_indicator = MessageStreamStatus.done_step
|
||||
# self.multi_step_gen_indicator = MessageStreamStatus.done_generation
|
||||
|
||||
# Support for AssistantMessage
|
||||
self.use_assistant_message = False # TODO: Remove this
|
||||
@@ -378,8 +378,8 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
"""Clean up the stream by deactivating and clearing chunks."""
|
||||
self.streaming_chat_completion_mode_function_name = None
|
||||
|
||||
if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
|
||||
self._push_to_buffer(self.multi_step_gen_indicator)
|
||||
# if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
|
||||
# self._push_to_buffer(self.multi_step_gen_indicator)
|
||||
|
||||
# Wipe the inner thoughts buffers
|
||||
self._reset_inner_thoughts_json_reader()
|
||||
@@ -390,9 +390,9 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# end the stream
|
||||
self._active = False
|
||||
self._event.set() # Unblock the generator if it's waiting to allow it to complete
|
||||
elif not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
|
||||
# signal that a new step has started in the stream
|
||||
self._push_to_buffer(self.multi_step_indicator)
|
||||
# elif not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
|
||||
# # signal that a new step has started in the stream
|
||||
# self._push_to_buffer(self.multi_step_indicator)
|
||||
|
||||
# Wipe the inner thoughts buffers
|
||||
self._reset_inner_thoughts_json_reader()
|
||||
|
||||
@@ -2,10 +2,10 @@ from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariable as PydanticEnvVar
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
|
||||
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
|
||||
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate
|
||||
from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticEnvVar
|
||||
from letta.schemas.sandbox_config import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate, SandboxType
|
||||
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate, SandboxType
|
||||
from letta.server.rest_api.utils import get_letta_server, get_user_id
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ from letta.schemas.embedding_config import EmbeddingConfig
|
||||
|
||||
# openai schemas
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
|
||||
from letta.schemas.job import Job, JobUpdate
|
||||
from letta.schemas.letta_message import LettaMessage, ToolReturnMessage
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
@@ -53,7 +54,7 @@ from letta.schemas.memory import ArchivalMemorySummary, ContextWindowOverview, M
|
||||
from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.sandbox_config import SandboxEnvironmentVariableCreate, SandboxType
|
||||
from letta.schemas.sandbox_config import SandboxType
|
||||
from letta.schemas.source import Source
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
|
||||
@@ -14,6 +14,7 @@ from letta.orm import Source as SourceModel
|
||||
from letta.orm import SourcePassage, SourcesAgents
|
||||
from letta.orm import Tool as ToolModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.sandbox_config import AgentEnvironmentVariable as AgentEnvironmentVariableModel
|
||||
from letta.orm.sqlite_functions import adapt_array
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent
|
||||
@@ -116,6 +117,14 @@ class AgentManager:
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# If there are provided environment variables, add them in
|
||||
if agent_create.tool_exec_environment_variables:
|
||||
agent_state = self._set_environment_variables(
|
||||
agent_id=agent_state.id,
|
||||
env_vars=agent_create.tool_exec_environment_variables,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# TODO: See if we can merge this into the above SQL create call for performance reasons
|
||||
# Generate a sequence of initial messages to put in the buffer
|
||||
init_messages = initialize_message_sequence(
|
||||
@@ -192,6 +201,14 @@ class AgentManager:
|
||||
def update_agent(self, agent_id: str, agent_update: UpdateAgent, actor: PydanticUser) -> PydanticAgentState:
|
||||
agent_state = self._update_agent(agent_id=agent_id, agent_update=agent_update, actor=actor)
|
||||
|
||||
# If there are provided environment variables, add them in
|
||||
if agent_update.tool_exec_environment_variables:
|
||||
agent_state = self._set_environment_variables(
|
||||
agent_id=agent_state.id,
|
||||
env_vars=agent_update.tool_exec_environment_variables,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# Rebuild the system prompt if it's different
|
||||
if agent_update.system and agent_update.system != agent_state.system:
|
||||
agent_state = self.rebuild_system_prompt(agent_id=agent_state.id, actor=actor, force=True, update_timestamp=False)
|
||||
@@ -296,6 +313,43 @@ class AgentManager:
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
agent.hard_delete(session)
|
||||
|
||||
# ======================================================================================================================
|
||||
# Per Agent Environment Variable Management
|
||||
# ======================================================================================================================
|
||||
@enforce_types
|
||||
def _set_environment_variables(
|
||||
self,
|
||||
agent_id: str,
|
||||
env_vars: Dict[str, str],
|
||||
actor: PydanticUser,
|
||||
) -> PydanticAgentState:
|
||||
"""
|
||||
Adds or replaces the environment variables for the specified agent.
|
||||
|
||||
Args:
|
||||
agent_id: The agent id.
|
||||
env_vars: A dictionary of environment variable key-value pairs.
|
||||
actor: The user performing the action.
|
||||
|
||||
Returns:
|
||||
PydanticAgentState: The updated agent as a Pydantic model.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
# Retrieve the agent
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# Replace the environment variables
|
||||
agent.tool_exec_environment_variables = [
|
||||
AgentEnvironmentVariableModel(key=key, value=value, agent_id=agent_id, organization_id=actor.organization_id)
|
||||
for key, value in env_vars.items()
|
||||
]
|
||||
|
||||
# Update the agent in the database
|
||||
agent.update(session, actor=actor)
|
||||
|
||||
# Return the updated agent state
|
||||
return agent.to_pydantic()
|
||||
|
||||
# ======================================================================================================================
|
||||
# In Context Messages Management
|
||||
# ======================================================================================================================
|
||||
|
||||
@@ -5,11 +5,11 @@ from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.sandbox_config import SandboxConfig as SandboxConfigModel
|
||||
from letta.orm.sandbox_config import SandboxEnvironmentVariable as SandboxEnvVarModel
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariable as PydanticEnvVar
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
|
||||
from letta.schemas.sandbox_config import LocalSandboxConfig
|
||||
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
|
||||
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate
|
||||
from letta.schemas.sandbox_config import SandboxEnvironmentVariable as PydanticEnvVar
|
||||
from letta.schemas.sandbox_config import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate, SandboxType
|
||||
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate, SandboxType
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.utils import enforce_types, printd
|
||||
|
||||
|
||||
@@ -278,11 +278,14 @@ class ToolExecutionSandbox:
|
||||
sbx = self.get_running_e2b_sandbox_with_same_state(sbx_config)
|
||||
if not sbx or self.force_recreate:
|
||||
if not sbx:
|
||||
logger.info(f"No running e2b sandbox found with the same state: {sbx_config}")
|
||||
logger.info(f"No running e2b sandbox found with the same state: {sbx_config}")
|
||||
else:
|
||||
logger.info(f"Force recreated e2b sandbox with state: {sbx_config}")
|
||||
logger.info(f"Force recreated e2b sandbox with state: {sbx_config}")
|
||||
sbx = self.create_e2b_sandbox_with_metadata_hash(sandbox_config=sbx_config)
|
||||
|
||||
logger.info(f"E2B Sandbox configurations: {sbx_config}")
|
||||
logger.info(f"E2B Sandbox ID: {sbx.sandbox_id}")
|
||||
|
||||
# Since this sandbox was used, we extend its lifecycle by the timeout
|
||||
sbx.set_timeout(sbx_config.get_e2b_config().timeout)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from letta.functions.functions import parse_source_code
|
||||
from letta.functions.schema_generator import generate_schema
|
||||
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
|
||||
|
||||
def cleanup(client: Union[LocalClient, RESTClient], agent_uuid: str):
|
||||
@@ -27,12 +28,19 @@ def create_tool_from_func(func: callable):
|
||||
)
|
||||
|
||||
|
||||
def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, UpdateAgent]):
|
||||
def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, UpdateAgent], actor: PydanticUser):
|
||||
# Assert scalar fields
|
||||
assert agent.system == request.system, f"System prompt mismatch: {agent.system} != {request.system}"
|
||||
assert agent.description == request.description, f"Description mismatch: {agent.description} != {request.description}"
|
||||
assert agent.metadata_ == request.metadata_, f"Metadata mismatch: {agent.metadata_} != {request.metadata_}"
|
||||
|
||||
# Assert agent env vars
|
||||
if hasattr(request, "tool_exec_environment_variables"):
|
||||
for agent_env_var in agent.tool_exec_environment_variables:
|
||||
assert agent_env_var.key in request.tool_exec_environment_variables
|
||||
assert request.tool_exec_environment_variables[agent_env_var.key] == agent_env_var.value
|
||||
assert agent_env_var.organization_id == actor.organization_id
|
||||
|
||||
# Assert agent type
|
||||
if hasattr(request, "agent_type"):
|
||||
assert agent.agent_type == request.agent_type, f"Agent type mismatch: {agent.agent_type} != {request.agent_type}"
|
||||
|
||||
@@ -9,20 +9,14 @@ from sqlalchemy import delete
|
||||
|
||||
from letta import create_client
|
||||
from letta.functions.function_sets.base import core_memory_append, core_memory_replace
|
||||
from letta.orm import SandboxConfig, SandboxEnvironmentVariable
|
||||
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ChatMemory
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.sandbox_config import (
|
||||
E2BSandboxConfig,
|
||||
LocalSandboxConfig,
|
||||
SandboxConfigCreate,
|
||||
SandboxConfigUpdate,
|
||||
SandboxEnvironmentVariableCreate,
|
||||
SandboxType,
|
||||
)
|
||||
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate, SandboxType
|
||||
from letta.schemas.tool import Tool, ToolCreate
|
||||
from letta.schemas.user import User
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
|
||||
@@ -249,8 +249,8 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent
|
||||
send_message_ran = False
|
||||
# 3. Check that we get all the start/stop/end tokens we want
|
||||
# This includes all of the MessageStreamStatus enums
|
||||
done_gen = False
|
||||
done_step = False
|
||||
# done_gen = False
|
||||
# done_step = False
|
||||
done = False
|
||||
|
||||
# print(response)
|
||||
@@ -266,12 +266,12 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent
|
||||
if chunk == MessageStreamStatus.done:
|
||||
assert not done, "Message stream already done"
|
||||
done = True
|
||||
elif chunk == MessageStreamStatus.done_step:
|
||||
assert not done_step, "Message stream already done step"
|
||||
done_step = True
|
||||
elif chunk == MessageStreamStatus.done_generation:
|
||||
assert not done_gen, "Message stream already done generation"
|
||||
done_gen = True
|
||||
# elif chunk == MessageStreamStatus.done_step:
|
||||
# assert not done_step, "Message stream already done step"
|
||||
# done_step = True
|
||||
# elif chunk == MessageStreamStatus.done_generation:
|
||||
# assert not done_gen, "Message stream already done generation"
|
||||
# done_gen = True
|
||||
if isinstance(chunk, LettaUsageStatistics):
|
||||
# Some rough metrics for a reasonable usage pattern
|
||||
assert chunk.step_count == 1
|
||||
@@ -284,8 +284,8 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent
|
||||
assert inner_thoughts_exist, "No inner thoughts found"
|
||||
assert send_message_ran, "send_message function call not found"
|
||||
assert done, "Message stream not done"
|
||||
assert done_step, "Message stream not done step"
|
||||
assert done_gen, "Message stream not done generation"
|
||||
# assert done_step, "Message stream not done step"
|
||||
# assert done_gen, "Message stream not done generation"
|
||||
|
||||
|
||||
def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
|
||||
@@ -35,6 +35,7 @@ from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.block import BlockUpdate, CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import JobStatus, MessageRole
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
|
||||
from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||||
from letta.schemas.job import Job as PydanticJob
|
||||
from letta.schemas.job import JobUpdate
|
||||
@@ -43,15 +44,7 @@ from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import MessageCreate, MessageUpdate
|
||||
from letta.schemas.organization import Organization as PydanticOrganization
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.sandbox_config import (
|
||||
E2BSandboxConfig,
|
||||
LocalSandboxConfig,
|
||||
SandboxConfigCreate,
|
||||
SandboxConfigUpdate,
|
||||
SandboxEnvironmentVariableCreate,
|
||||
SandboxEnvironmentVariableUpdate,
|
||||
SandboxType,
|
||||
)
|
||||
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate, SandboxType
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
from letta.schemas.source import SourceUpdate
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
@@ -413,6 +406,7 @@ def comprehensive_test_agent_fixture(server: SyncServer, default_user, print_too
|
||||
metadata_={"test_key": "test_value"},
|
||||
tool_rules=[InitToolRule(tool_name=print_tool.name)],
|
||||
initial_message_sequence=[MessageCreate(role=MessageRole.user, text="hello world")],
|
||||
tool_exec_environment_variables={"test_env_var_key_a": "test_env_var_value_a", "test_env_var_key_b": "test_env_var_value_b"},
|
||||
)
|
||||
created_agent = server.agent_manager.create_agent(
|
||||
create_agent_request,
|
||||
@@ -482,20 +476,20 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent):
|
||||
def test_create_get_list_agent(server: SyncServer, comprehensive_test_agent_fixture, default_user):
|
||||
# Test agent creation
|
||||
created_agent, create_agent_request = comprehensive_test_agent_fixture
|
||||
comprehensive_agent_checks(created_agent, create_agent_request)
|
||||
comprehensive_agent_checks(created_agent, create_agent_request, actor=default_user)
|
||||
|
||||
# Test get agent
|
||||
get_agent = server.agent_manager.get_agent_by_id(agent_id=created_agent.id, actor=default_user)
|
||||
comprehensive_agent_checks(get_agent, create_agent_request)
|
||||
comprehensive_agent_checks(get_agent, create_agent_request, actor=default_user)
|
||||
|
||||
# Test get agent name
|
||||
get_agent_name = server.agent_manager.get_agent_by_name(agent_name=created_agent.name, actor=default_user)
|
||||
comprehensive_agent_checks(get_agent_name, create_agent_request)
|
||||
comprehensive_agent_checks(get_agent_name, create_agent_request, actor=default_user)
|
||||
|
||||
# Test list agent
|
||||
list_agents = server.agent_manager.list_agents(actor=default_user)
|
||||
assert len(list_agents) == 1
|
||||
comprehensive_agent_checks(list_agents[0], create_agent_request)
|
||||
comprehensive_agent_checks(list_agents[0], create_agent_request, actor=default_user)
|
||||
|
||||
# Test deleting the agent
|
||||
server.agent_manager.delete_agent(get_agent.id, default_user)
|
||||
@@ -566,10 +560,11 @@ def test_update_agent(server: SyncServer, comprehensive_test_agent_fixture, othe
|
||||
embedding_config=EmbeddingConfig.default_config(model_name="letta"),
|
||||
message_ids=["10", "20"],
|
||||
metadata_={"train_key": "train_value"},
|
||||
tool_exec_environment_variables={"new_tool_exec_key": "new_tool_exec_value"},
|
||||
)
|
||||
|
||||
updated_agent = server.agent_manager.update_agent(agent.id, update_agent_request, actor=default_user)
|
||||
comprehensive_agent_checks(updated_agent, update_agent_request)
|
||||
comprehensive_agent_checks(updated_agent, update_agent_request, actor=default_user)
|
||||
assert updated_agent.message_ids == update_agent_request.message_ids
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user