feat(asyncify): convert get composio key and sandbox config manager (#2436)

This commit is contained in:
cthomas
2025-05-26 13:25:33 -07:00
committed by GitHub
parent fb51e226ab
commit b1f38779cd
5 changed files with 197 additions and 50 deletions

View File

@@ -20,3 +20,19 @@ def get_composio_api_key(actor: User, logger: Optional[Logger] = None) -> Option
# Ideally, not tied to a specific sandbox, but for now we just get the first one
# Theoretically possible for someone to have different composio api keys per sandbox
return api_keys[0].value
async def get_composio_api_key_async(actor: User, logger: Optional[Logger] = None) -> Optional[str]:
api_keys = await SandboxConfigManager().list_sandbox_env_vars_by_key_async(key="COMPOSIO_API_KEY", actor=actor)
if not api_keys:
if logger:
logger.debug(f"No API keys found for Composio. Defaulting to the environment variable...")
if tool_settings.composio_api_key:
return tool_settings.composio_api_key
else:
return None
else:
# TODO: Add more protections around this
# Ideally, not tied to a specific sandbox, but for now we just get the first one
# Theoretically possible for someone to have different composio api keys per sandbox
return api_keys[0].value

View File

@@ -90,13 +90,13 @@ async def update_sandbox_config(
@router.delete("/{sandbox_config_id}", status_code=204)
def delete_sandbox_config(
async def delete_sandbox_config(
sandbox_config_id: str,
server: SyncServer = Depends(get_letta_server),
actor_id: str = Depends(get_user_id),
):
actor = server.user_manager.get_user_or_default(user_id=actor_id)
server.sandbox_config_manager.delete_sandbox_config(sandbox_config_id, actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
await server.sandbox_config_manager.delete_sandbox_config_async(sandbox_config_id, actor)
@router.get("/", response_model=List[PydanticSandboxConfig])
@@ -158,35 +158,35 @@ async def force_recreate_local_sandbox_venv(
@router.post("/{sandbox_config_id}/environment-variable", response_model=PydanticEnvVar)
def create_sandbox_env_var(
async def create_sandbox_env_var(
sandbox_config_id: str,
env_var_create: SandboxEnvironmentVariableCreate,
server: SyncServer = Depends(get_letta_server),
actor_id: str = Depends(get_user_id),
):
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.sandbox_config_manager.create_sandbox_env_var(env_var_create, sandbox_config_id, actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.sandbox_config_manager.create_sandbox_env_var_async(env_var_create, sandbox_config_id, actor)
@router.patch("/environment-variable/{env_var_id}", response_model=PydanticEnvVar)
def update_sandbox_env_var(
async def update_sandbox_env_var(
env_var_id: str,
env_var_update: SandboxEnvironmentVariableUpdate,
server: SyncServer = Depends(get_letta_server),
actor_id: str = Depends(get_user_id),
):
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.sandbox_config_manager.update_sandbox_env_var(env_var_id, env_var_update, actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.sandbox_config_manager.update_sandbox_env_var_async(env_var_id, env_var_update, actor)
@router.delete("/environment-variable/{env_var_id}", status_code=204)
def delete_sandbox_env_var(
async def delete_sandbox_env_var(
env_var_id: str,
server: SyncServer = Depends(get_letta_server),
actor_id: str = Depends(get_user_id),
):
actor = server.user_manager.get_user_or_default(user_id=actor_id)
server.sandbox_config_manager.delete_sandbox_env_var(env_var_id, actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
await server.sandbox_config_manager.delete_sandbox_env_var_async(env_var_id, actor)
@router.get("/{sandbox_config_id}/environment-variable", response_model=List[PydanticEnvVar])

View File

@@ -192,6 +192,15 @@ class SandboxConfigManager:
sandbox.hard_delete(db_session=session, actor=actor)
return sandbox.to_pydantic()
@enforce_types
@trace_method
async def delete_sandbox_config_async(self, sandbox_config_id: str, actor: PydanticUser) -> PydanticSandboxConfig:
"""Delete a sandbox configuration by its ID."""
async with db_registry.async_session() as session:
sandbox = await SandboxConfigModel.read_async(db_session=session, identifier=sandbox_config_id, actor=actor)
await sandbox.hard_delete_async(db_session=session, actor=actor)
return sandbox.to_pydantic()
@enforce_types
@trace_method
def list_sandbox_configs(
@@ -305,6 +314,34 @@ class SandboxConfigManager:
env_var.create(session, actor=actor)
return env_var.to_pydantic()
@enforce_types
@trace_method
async def create_sandbox_env_var_async(
self, env_var_create: SandboxEnvironmentVariableCreate, sandbox_config_id: str, actor: PydanticUser
) -> PydanticEnvVar:
"""Create a new sandbox environment variable."""
env_var = PydanticEnvVar(**env_var_create.model_dump(), sandbox_config_id=sandbox_config_id, organization_id=actor.organization_id)
db_env_var = await self.get_sandbox_env_var_by_key_and_sandbox_config_id_async(env_var.key, env_var.sandbox_config_id, actor=actor)
if db_env_var:
update_data = env_var.model_dump(exclude_unset=True, exclude_none=True)
update_data = {key: value for key, value in update_data.items() if getattr(db_env_var, key) != value}
# If there are changes, update the environment variable
if update_data:
db_env_var = await self.update_sandbox_env_var_async(db_env_var.id, SandboxEnvironmentVariableUpdate(**update_data), actor)
else:
printd(
f"`create_or_update_sandbox_env_var` was called with user_id={actor.id}, organization_id={actor.organization_id}, "
f"key={env_var.key}, but found existing variable with nothing to update."
)
return db_env_var
else:
async with db_registry.async_session() as session:
env_var = SandboxEnvVarModel(**env_var.model_dump(to_orm=True, exclude_none=True))
await env_var.create_async(session, actor=actor)
return env_var.to_pydantic()
@enforce_types
@trace_method
def update_sandbox_env_var(
@@ -327,6 +364,28 @@ class SandboxConfigManager:
)
return env_var.to_pydantic()
@enforce_types
@trace_method
async def update_sandbox_env_var_async(
self, env_var_id: str, env_var_update: SandboxEnvironmentVariableUpdate, actor: PydanticUser
) -> PydanticEnvVar:
"""Update an existing sandbox environment variable."""
async with db_registry.async_session() as session:
env_var = await SandboxEnvVarModel.read_async(db_session=session, identifier=env_var_id, actor=actor)
update_data = env_var_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
update_data = {key: value for key, value in update_data.items() if getattr(env_var, key) != value}
if update_data:
for key, value in update_data.items():
setattr(env_var, key, value)
await env_var.update_async(db_session=session, actor=actor)
else:
printd(
f"`update_sandbox_env_var` called with user_id={actor.id}, organization_id={actor.organization_id}, "
f"key={env_var.key}, but nothing to update."
)
return env_var.to_pydantic()
@enforce_types
@trace_method
def delete_sandbox_env_var(self, env_var_id: str, actor: PydanticUser) -> PydanticEnvVar:
@@ -336,6 +395,15 @@ class SandboxConfigManager:
env_var.hard_delete(db_session=session, actor=actor)
return env_var.to_pydantic()
@enforce_types
@trace_method
async def delete_sandbox_env_var_async(self, env_var_id: str, actor: PydanticUser) -> PydanticEnvVar:
"""Delete a sandbox environment variable by its ID."""
async with db_registry.async_session() as session:
env_var = await SandboxEnvVarModel.read_async(db_session=session, identifier=env_var_id, actor=actor)
await env_var.hard_delete_async(db_session=session, actor=actor)
return env_var.to_pydantic()
@enforce_types
@trace_method
def list_sandbox_env_vars(
@@ -392,6 +460,22 @@ class SandboxConfigManager:
)
return [env_var.to_pydantic() for env_var in env_vars]
@enforce_types
@trace_method
async def list_sandbox_env_vars_by_key_async(
self, key: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50
) -> List[PydanticEnvVar]:
"""List all sandbox environment variables with optional pagination."""
async with db_registry.async_session() as session:
env_vars = await SandboxEnvVarModel.list_async(
db_session=session,
after=after,
limit=limit,
organization_id=actor.organization_id,
key=key,
)
return [env_var.to_pydantic() for env_var in env_vars]
@enforce_types
@trace_method
def get_sandbox_env_vars_as_dict(
@@ -434,3 +518,24 @@ class SandboxConfigManager:
return None
except NoResultFound:
return None
@enforce_types
@trace_method
async def get_sandbox_env_var_by_key_and_sandbox_config_id_async(
self, key: str, sandbox_config_id: str, actor: Optional[PydanticUser] = None
) -> Optional[PydanticEnvVar]:
"""Retrieve a sandbox environment variable by its key and sandbox_config_id."""
async with db_registry.async_session() as session:
try:
env_var = await SandboxEnvVarModel.list_async(
db_session=session,
key=key,
sandbox_config_id=sandbox_config_id,
organization_id=actor.organization_id,
limit=1,
)
if env_var:
return env_var[0].to_pydantic()
return None
except NoResultFound:
return None

View File

@@ -19,7 +19,7 @@ from letta.constants import (
)
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
from letta.functions.composio_helpers import execute_composio_action_async, generate_composio_action_from_func_name
from letta.helpers.composio_helpers import get_composio_api_key
from letta.helpers.composio_helpers import get_composio_api_key_async
from letta.helpers.json_helpers import json_dumps
from letta.log import get_logger
from letta.schemas.agent import AgentState
@@ -656,7 +656,7 @@ class ExternalComposioToolExecutor(ToolExecutor):
entity_id = self._get_entity_id(agent_state)
# Get composio_api_key
composio_api_key = get_composio_api_key(actor=actor)
composio_api_key = await get_composio_api_key_async(actor=actor)
# TODO (matt): Roll in execute_composio_action into this class
function_response = await execute_composio_action_async(

View File

@@ -4083,11 +4083,12 @@ async def test_delete_file(server: SyncServer, default_user, default_source):
# ======================================================================================================================
def test_create_or_update_sandbox_config(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_create_or_update_sandbox_config(server: SyncServer, default_user, event_loop):
sandbox_config_create = SandboxConfigCreate(
config=E2BSandboxConfig(),
)
created_config = server.sandbox_config_manager.create_or_update_sandbox_config(sandbox_config_create, actor=default_user)
created_config = await server.sandbox_config_manager.create_or_update_sandbox_config_async(sandbox_config_create, actor=default_user)
# Assertions
assert created_config.type == SandboxType.E2B
@@ -4095,11 +4096,12 @@ def test_create_or_update_sandbox_config(server: SyncServer, default_user):
assert created_config.organization_id == default_user.organization_id
def test_create_local_sandbox_config_defaults(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_create_local_sandbox_config_defaults(server: SyncServer, default_user, event_loop):
sandbox_config_create = SandboxConfigCreate(
config=LocalSandboxConfig(),
)
created_config = server.sandbox_config_manager.create_or_update_sandbox_config(sandbox_config_create, actor=default_user)
created_config = await server.sandbox_config_manager.create_or_update_sandbox_config_async(sandbox_config_create, actor=default_user)
# Assertions
assert created_config.type == SandboxType.LOCAL
@@ -4108,8 +4110,11 @@ def test_create_local_sandbox_config_defaults(server: SyncServer, default_user):
assert created_config.organization_id == default_user.organization_id
def test_default_e2b_settings_sandbox_config(server: SyncServer, default_user):
created_config = server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=default_user)
@pytest.mark.asyncio
async def test_default_e2b_settings_sandbox_config(server: SyncServer, default_user, event_loop):
created_config = await server.sandbox_config_manager.get_or_create_default_sandbox_config_async(
sandbox_type=SandboxType.E2B, actor=default_user
)
e2b_config = created_config.get_e2b_config()
# Assertions
@@ -4117,35 +4122,41 @@ def test_default_e2b_settings_sandbox_config(server: SyncServer, default_user):
assert e2b_config.template == tool_settings.e2b_sandbox_template_id
def test_update_existing_sandbox_config(server: SyncServer, sandbox_config_fixture, default_user):
@pytest.mark.asyncio
async def test_update_existing_sandbox_config(server: SyncServer, sandbox_config_fixture, default_user, event_loop):
update_data = SandboxConfigUpdate(config=E2BSandboxConfig(template="template_2", timeout=120))
updated_config = server.sandbox_config_manager.update_sandbox_config(sandbox_config_fixture.id, update_data, actor=default_user)
updated_config = await server.sandbox_config_manager.update_sandbox_config_async(
sandbox_config_fixture.id, update_data, actor=default_user
)
# Assertions
assert updated_config.config["template"] == "template_2"
assert updated_config.config["timeout"] == 120
def test_delete_sandbox_config(server: SyncServer, sandbox_config_fixture, default_user):
deleted_config = server.sandbox_config_manager.delete_sandbox_config(sandbox_config_fixture.id, actor=default_user)
@pytest.mark.asyncio
async def test_delete_sandbox_config(server: SyncServer, sandbox_config_fixture, default_user, event_loop):
deleted_config = await server.sandbox_config_manager.delete_sandbox_config_async(sandbox_config_fixture.id, actor=default_user)
# Assertions to verify deletion
assert deleted_config.id == sandbox_config_fixture.id
# Verify it no longer exists
config_list = server.sandbox_config_manager.list_sandbox_configs(actor=default_user)
config_list = await server.sandbox_config_manager.list_sandbox_configs_async(actor=default_user)
assert sandbox_config_fixture.id not in [config.id for config in config_list]
def test_get_sandbox_config_by_type(server: SyncServer, sandbox_config_fixture, default_user):
retrieved_config = server.sandbox_config_manager.get_sandbox_config_by_type(sandbox_config_fixture.type, actor=default_user)
@pytest.mark.asyncio
async def test_get_sandbox_config_by_type(server: SyncServer, sandbox_config_fixture, default_user, event_loop):
retrieved_config = await server.sandbox_config_manager.get_sandbox_config_by_type_async(sandbox_config_fixture.type, actor=default_user)
# Assertions to verify correct retrieval
assert retrieved_config.id == sandbox_config_fixture.id
assert retrieved_config.type == sandbox_config_fixture.type
def test_list_sandbox_configs(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_list_sandbox_configs(server: SyncServer, default_user, event_loop):
# Creating multiple sandbox configs
config_e2b_create = SandboxConfigCreate(
config=E2BSandboxConfig(),
@@ -4153,29 +4164,29 @@ def test_list_sandbox_configs(server: SyncServer, default_user):
config_local_create = SandboxConfigCreate(
config=LocalSandboxConfig(sandbox_dir=""),
)
config_e2b = server.sandbox_config_manager.create_or_update_sandbox_config(config_e2b_create, actor=default_user)
config_e2b = await server.sandbox_config_manager.create_or_update_sandbox_config_async(config_e2b_create, actor=default_user)
if USING_SQLITE:
time.sleep(CREATE_DELAY_SQLITE)
config_local = server.sandbox_config_manager.create_or_update_sandbox_config(config_local_create, actor=default_user)
config_local = await server.sandbox_config_manager.create_or_update_sandbox_config_async(config_local_create, actor=default_user)
# List configs without pagination
configs = server.sandbox_config_manager.list_sandbox_configs(actor=default_user)
configs = await server.sandbox_config_manager.list_sandbox_configs_async(actor=default_user)
assert len(configs) >= 2
# List configs with pagination
paginated_configs = server.sandbox_config_manager.list_sandbox_configs(actor=default_user, limit=1)
paginated_configs = await server.sandbox_config_manager.list_sandbox_configs_async(actor=default_user, limit=1)
assert len(paginated_configs) == 1
next_page = server.sandbox_config_manager.list_sandbox_configs(actor=default_user, after=paginated_configs[-1].id, limit=1)
next_page = await server.sandbox_config_manager.list_sandbox_configs_async(actor=default_user, after=paginated_configs[-1].id, limit=1)
assert len(next_page) == 1
assert next_page[0].id != paginated_configs[0].id
# List configs using sandbox_type filter
configs = server.sandbox_config_manager.list_sandbox_configs(actor=default_user, sandbox_type=SandboxType.E2B)
configs = await server.sandbox_config_manager.list_sandbox_configs_async(actor=default_user, sandbox_type=SandboxType.E2B)
assert len(configs) == 1
assert configs[0].id == config_e2b.id
configs = server.sandbox_config_manager.list_sandbox_configs(actor=default_user, sandbox_type=SandboxType.LOCAL)
configs = await server.sandbox_config_manager.list_sandbox_configs_async(actor=default_user, sandbox_type=SandboxType.LOCAL)
assert len(configs) == 1
assert configs[0].id == config_local.id
@@ -4185,9 +4196,10 @@ def test_list_sandbox_configs(server: SyncServer, default_user):
# ======================================================================================================================
def test_create_sandbox_env_var(server: SyncServer, sandbox_config_fixture, default_user):
@pytest.mark.asyncio
async def test_create_sandbox_env_var(server: SyncServer, sandbox_config_fixture, default_user, event_loop):
env_var_create = SandboxEnvironmentVariableCreate(key="TEST_VAR", value="test_value", description="A test environment variable.")
created_env_var = server.sandbox_config_manager.create_sandbox_env_var(
created_env_var = await server.sandbox_config_manager.create_sandbox_env_var_async(
env_var_create, sandbox_config_id=sandbox_config_fixture.id, actor=default_user
)
@@ -4197,54 +4209,68 @@ def test_create_sandbox_env_var(server: SyncServer, sandbox_config_fixture, defa
assert created_env_var.organization_id == default_user.organization_id
def test_update_sandbox_env_var(server: SyncServer, sandbox_env_var_fixture, default_user):
@pytest.mark.asyncio
async def test_update_sandbox_env_var(server: SyncServer, sandbox_env_var_fixture, default_user, event_loop):
update_data = SandboxEnvironmentVariableUpdate(value="updated_value")
updated_env_var = server.sandbox_config_manager.update_sandbox_env_var(sandbox_env_var_fixture.id, update_data, actor=default_user)
updated_env_var = await server.sandbox_config_manager.update_sandbox_env_var_async(
sandbox_env_var_fixture.id, update_data, actor=default_user
)
# Assertions
assert updated_env_var.value == "updated_value"
assert updated_env_var.id == sandbox_env_var_fixture.id
def test_delete_sandbox_env_var(server: SyncServer, sandbox_config_fixture, sandbox_env_var_fixture, default_user):
deleted_env_var = server.sandbox_config_manager.delete_sandbox_env_var(sandbox_env_var_fixture.id, actor=default_user)
@pytest.mark.asyncio
async def test_delete_sandbox_env_var(server: SyncServer, sandbox_config_fixture, sandbox_env_var_fixture, default_user, event_loop):
deleted_env_var = await server.sandbox_config_manager.delete_sandbox_env_var_async(sandbox_env_var_fixture.id, actor=default_user)
# Assertions to verify deletion
assert deleted_env_var.id == sandbox_env_var_fixture.id
# Verify it no longer exists
env_vars = server.sandbox_config_manager.list_sandbox_env_vars(sandbox_config_id=sandbox_config_fixture.id, actor=default_user)
env_vars = await server.sandbox_config_manager.list_sandbox_env_vars_async(
sandbox_config_id=sandbox_config_fixture.id, actor=default_user
)
assert sandbox_env_var_fixture.id not in [env_var.id for env_var in env_vars]
def test_list_sandbox_env_vars(server: SyncServer, sandbox_config_fixture, default_user):
@pytest.mark.asyncio
async def test_list_sandbox_env_vars(server: SyncServer, sandbox_config_fixture, default_user, event_loop):
# Creating multiple environment variables
env_var_create_a = SandboxEnvironmentVariableCreate(key="VAR1", value="value1")
env_var_create_b = SandboxEnvironmentVariableCreate(key="VAR2", value="value2")
server.sandbox_config_manager.create_sandbox_env_var(env_var_create_a, sandbox_config_id=sandbox_config_fixture.id, actor=default_user)
await server.sandbox_config_manager.create_sandbox_env_var_async(
env_var_create_a, sandbox_config_id=sandbox_config_fixture.id, actor=default_user
)
if USING_SQLITE:
time.sleep(CREATE_DELAY_SQLITE)
server.sandbox_config_manager.create_sandbox_env_var(env_var_create_b, sandbox_config_id=sandbox_config_fixture.id, actor=default_user)
await server.sandbox_config_manager.create_sandbox_env_var_async(
env_var_create_b, sandbox_config_id=sandbox_config_fixture.id, actor=default_user
)
# List env vars without pagination
env_vars = server.sandbox_config_manager.list_sandbox_env_vars(sandbox_config_id=sandbox_config_fixture.id, actor=default_user)
env_vars = await server.sandbox_config_manager.list_sandbox_env_vars_async(
sandbox_config_id=sandbox_config_fixture.id, actor=default_user
)
assert len(env_vars) >= 2
# List env vars with pagination
paginated_env_vars = server.sandbox_config_manager.list_sandbox_env_vars(
paginated_env_vars = await server.sandbox_config_manager.list_sandbox_env_vars_async(
sandbox_config_id=sandbox_config_fixture.id, actor=default_user, limit=1
)
assert len(paginated_env_vars) == 1
next_page = server.sandbox_config_manager.list_sandbox_env_vars(
next_page = await server.sandbox_config_manager.list_sandbox_env_vars_async(
sandbox_config_id=sandbox_config_fixture.id, actor=default_user, after=paginated_env_vars[-1].id, limit=1
)
assert len(next_page) == 1
assert next_page[0].id != paginated_env_vars[0].id
def test_get_sandbox_env_var_by_key(server: SyncServer, sandbox_env_var_fixture, default_user):
retrieved_env_var = server.sandbox_config_manager.get_sandbox_env_var_by_key_and_sandbox_config_id(
@pytest.mark.asyncio
async def test_get_sandbox_env_var_by_key(server: SyncServer, sandbox_env_var_fixture, default_user, event_loop):
retrieved_env_var = await server.sandbox_config_manager.get_sandbox_env_var_by_key_and_sandbox_config_id_async(
sandbox_env_var_fixture.key, sandbox_env_var_fixture.sandbox_config_id, actor=default_user
)