diff --git a/letta/helpers/composio_helpers.py b/letta/helpers/composio_helpers.py index 2a0281e1..7a142f7c 100644 --- a/letta/helpers/composio_helpers.py +++ b/letta/helpers/composio_helpers.py @@ -20,3 +20,19 @@ def get_composio_api_key(actor: User, logger: Optional[Logger] = None) -> Option # Ideally, not tied to a specific sandbox, but for now we just get the first one # Theoretically possible for someone to have different composio api keys per sandbox return api_keys[0].value + + +async def get_composio_api_key_async(actor: User, logger: Optional[Logger] = None) -> Optional[str]: + api_keys = await SandboxConfigManager().list_sandbox_env_vars_by_key_async(key="COMPOSIO_API_KEY", actor=actor) + if not api_keys: + if logger: + logger.debug(f"No API keys found for Composio. Defaulting to the environment variable...") + if tool_settings.composio_api_key: + return tool_settings.composio_api_key + else: + return None + else: + # TODO: Add more protections around this + # Ideally, not tied to a specific sandbox, but for now we just get the first one + # Theoretically possible for someone to have different composio api keys per sandbox + return api_keys[0].value diff --git a/letta/server/rest_api/routers/v1/sandbox_configs.py b/letta/server/rest_api/routers/v1/sandbox_configs.py index 00681ea2..d82e9fff 100644 --- a/letta/server/rest_api/routers/v1/sandbox_configs.py +++ b/letta/server/rest_api/routers/v1/sandbox_configs.py @@ -90,13 +90,13 @@ async def update_sandbox_config( @router.delete("/{sandbox_config_id}", status_code=204) -def delete_sandbox_config( +async def delete_sandbox_config( sandbox_config_id: str, server: SyncServer = Depends(get_letta_server), actor_id: str = Depends(get_user_id), ): - actor = server.user_manager.get_user_or_default(user_id=actor_id) - server.sandbox_config_manager.delete_sandbox_config(sandbox_config_id, actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + await server.sandbox_config_manager.delete_sandbox_config_async(sandbox_config_id, actor) @router.get("/", response_model=List[PydanticSandboxConfig]) @@ -158,35 +158,35 @@ async def force_recreate_local_sandbox_venv( @router.post("/{sandbox_config_id}/environment-variable", response_model=PydanticEnvVar) -def create_sandbox_env_var( +async def create_sandbox_env_var( sandbox_config_id: str, env_var_create: SandboxEnvironmentVariableCreate, server: SyncServer = Depends(get_letta_server), actor_id: str = Depends(get_user_id), ): - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.sandbox_config_manager.create_sandbox_env_var(env_var_create, sandbox_config_id, actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.sandbox_config_manager.create_sandbox_env_var_async(env_var_create, sandbox_config_id, actor) @router.patch("/environment-variable/{env_var_id}", response_model=PydanticEnvVar) -def update_sandbox_env_var( +async def update_sandbox_env_var( env_var_id: str, env_var_update: SandboxEnvironmentVariableUpdate, server: SyncServer = Depends(get_letta_server), actor_id: str = Depends(get_user_id), ): - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.sandbox_config_manager.update_sandbox_env_var(env_var_id, env_var_update, actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.sandbox_config_manager.update_sandbox_env_var_async(env_var_id, env_var_update, actor) @router.delete("/environment-variable/{env_var_id}", status_code=204) -def delete_sandbox_env_var( +async def delete_sandbox_env_var( env_var_id: str, server: SyncServer = Depends(get_letta_server), actor_id: str = Depends(get_user_id), ): - actor = server.user_manager.get_user_or_default(user_id=actor_id) - server.sandbox_config_manager.delete_sandbox_env_var(env_var_id, actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + await server.sandbox_config_manager.delete_sandbox_env_var_async(env_var_id, actor) @router.get("/{sandbox_config_id}/environment-variable", response_model=List[PydanticEnvVar]) diff --git a/letta/services/sandbox_config_manager.py b/letta/services/sandbox_config_manager.py index 6e7a43bc..a8d17fb0 100644 --- a/letta/services/sandbox_config_manager.py +++ b/letta/services/sandbox_config_manager.py @@ -192,6 +192,15 @@ class SandboxConfigManager: sandbox.hard_delete(db_session=session, actor=actor) return sandbox.to_pydantic() + @enforce_types + @trace_method + async def delete_sandbox_config_async(self, sandbox_config_id: str, actor: PydanticUser) -> PydanticSandboxConfig: + """Delete a sandbox configuration by its ID.""" + async with db_registry.async_session() as session: + sandbox = await SandboxConfigModel.read_async(db_session=session, identifier=sandbox_config_id, actor=actor) + await sandbox.hard_delete_async(db_session=session, actor=actor) + return sandbox.to_pydantic() + @enforce_types @trace_method def list_sandbox_configs( @@ -305,6 +314,34 @@ class SandboxConfigManager: env_var.create(session, actor=actor) return env_var.to_pydantic() + @enforce_types + @trace_method + async def create_sandbox_env_var_async( + self, env_var_create: SandboxEnvironmentVariableCreate, sandbox_config_id: str, actor: PydanticUser + ) -> PydanticEnvVar: + """Create a new sandbox environment variable.""" + env_var = PydanticEnvVar(**env_var_create.model_dump(), sandbox_config_id=sandbox_config_id, organization_id=actor.organization_id) + + db_env_var = await self.get_sandbox_env_var_by_key_and_sandbox_config_id_async(env_var.key, env_var.sandbox_config_id, actor=actor) + if db_env_var: + update_data = env_var.model_dump(exclude_unset=True, exclude_none=True) + update_data = {key: value for key, value in update_data.items() if getattr(db_env_var, key) != value} + # If there are changes, update the environment variable + if update_data: + db_env_var = await self.update_sandbox_env_var_async(db_env_var.id, SandboxEnvironmentVariableUpdate(**update_data), actor) + else: + printd( + f"`create_or_update_sandbox_env_var` was called with user_id={actor.id}, organization_id={actor.organization_id}, " + f"key={env_var.key}, but found existing variable with nothing to update." + ) + + return db_env_var + else: + async with db_registry.async_session() as session: + env_var = SandboxEnvVarModel(**env_var.model_dump(to_orm=True, exclude_none=True)) + await env_var.create_async(session, actor=actor) + return env_var.to_pydantic() + @enforce_types @trace_method def update_sandbox_env_var( @@ -327,6 +364,28 @@ class SandboxConfigManager: ) return env_var.to_pydantic() + @enforce_types + @trace_method + async def update_sandbox_env_var_async( + self, env_var_id: str, env_var_update: SandboxEnvironmentVariableUpdate, actor: PydanticUser + ) -> PydanticEnvVar: + """Update an existing sandbox environment variable.""" + async with db_registry.async_session() as session: + env_var = await SandboxEnvVarModel.read_async(db_session=session, identifier=env_var_id, actor=actor) + update_data = env_var_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) + update_data = {key: value for key, value in update_data.items() if getattr(env_var, key) != value} + + if update_data: + for key, value in update_data.items(): + setattr(env_var, key, value) + await env_var.update_async(db_session=session, actor=actor) + else: + printd( + f"`update_sandbox_env_var` called with user_id={actor.id}, organization_id={actor.organization_id}, " + f"key={env_var.key}, but nothing to update." + ) + return env_var.to_pydantic() + @enforce_types @trace_method def delete_sandbox_env_var(self, env_var_id: str, actor: PydanticUser) -> PydanticEnvVar: @@ -336,6 +395,15 @@ class SandboxConfigManager: env_var.hard_delete(db_session=session, actor=actor) return env_var.to_pydantic() + @enforce_types + @trace_method + async def delete_sandbox_env_var_async(self, env_var_id: str, actor: PydanticUser) -> PydanticEnvVar: + """Delete a sandbox environment variable by its ID.""" + async with db_registry.async_session() as session: + env_var = await SandboxEnvVarModel.read_async(db_session=session, identifier=env_var_id, actor=actor) + await env_var.hard_delete_async(db_session=session, actor=actor) + return env_var.to_pydantic() + @enforce_types @trace_method def list_sandbox_env_vars( @@ -392,6 +460,22 @@ class SandboxConfigManager: ) return [env_var.to_pydantic() for env_var in env_vars] + @enforce_types + @trace_method + async def list_sandbox_env_vars_by_key_async( + self, key: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50 + ) -> List[PydanticEnvVar]: + """List all sandbox environment variables with optional pagination.""" + async with db_registry.async_session() as session: + env_vars = await SandboxEnvVarModel.list_async( + db_session=session, + after=after, + limit=limit, + organization_id=actor.organization_id, + key=key, + ) + return [env_var.to_pydantic() for env_var in env_vars] + @enforce_types @trace_method def get_sandbox_env_vars_as_dict( @@ -434,3 +518,24 @@ class SandboxConfigManager: return None except NoResultFound: return None + + @enforce_types + @trace_method + async def get_sandbox_env_var_by_key_and_sandbox_config_id_async( + self, key: str, sandbox_config_id: str, actor: Optional[PydanticUser] = None + ) -> Optional[PydanticEnvVar]: + """Retrieve a sandbox environment variable by its key and sandbox_config_id.""" + async with db_registry.async_session() as session: + try: + env_var = await SandboxEnvVarModel.list_async( + db_session=session, + key=key, + sandbox_config_id=sandbox_config_id, + organization_id=actor.organization_id, + limit=1, + ) + if env_var: + return env_var[0].to_pydantic() + return None + except NoResultFound: + return None diff --git a/letta/services/tool_executor/tool_executor.py b/letta/services/tool_executor/tool_executor.py index 0a205ebe..104a4446 100644 --- a/letta/services/tool_executor/tool_executor.py +++ b/letta/services/tool_executor/tool_executor.py @@ -19,7 +19,7 @@ from letta.constants import ( ) from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source from letta.functions.composio_helpers import execute_composio_action_async, generate_composio_action_from_func_name -from letta.helpers.composio_helpers import get_composio_api_key +from letta.helpers.composio_helpers import get_composio_api_key_async from letta.helpers.json_helpers import json_dumps from letta.log import get_logger from letta.schemas.agent import AgentState @@ -656,7 +656,7 @@ class ExternalComposioToolExecutor(ToolExecutor): entity_id = self._get_entity_id(agent_state) # Get composio_api_key - composio_api_key = get_composio_api_key(actor=actor) + composio_api_key = await get_composio_api_key_async(actor=actor) # TODO (matt): Roll in execute_composio_action into this class function_response = await execute_composio_action_async( diff --git a/tests/test_managers.py b/tests/test_managers.py index 146872c9..c9f578e6 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -4083,11 +4083,12 @@ async def test_delete_file(server: SyncServer, default_user, default_source): # ====================================================================================================================== -def test_create_or_update_sandbox_config(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_create_or_update_sandbox_config(server: SyncServer, default_user, event_loop): sandbox_config_create = SandboxConfigCreate( config=E2BSandboxConfig(), ) - created_config = server.sandbox_config_manager.create_or_update_sandbox_config(sandbox_config_create, actor=default_user) + created_config = await server.sandbox_config_manager.create_or_update_sandbox_config_async(sandbox_config_create, actor=default_user) # Assertions assert created_config.type == SandboxType.E2B @@ -4095,11 +4096,12 @@ def test_create_or_update_sandbox_config(server: SyncServer, default_user): assert created_config.organization_id == default_user.organization_id -def test_create_local_sandbox_config_defaults(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_create_local_sandbox_config_defaults(server: SyncServer, default_user, event_loop): sandbox_config_create = SandboxConfigCreate( config=LocalSandboxConfig(), ) - created_config = server.sandbox_config_manager.create_or_update_sandbox_config(sandbox_config_create, actor=default_user) + created_config = await server.sandbox_config_manager.create_or_update_sandbox_config_async(sandbox_config_create, actor=default_user) # Assertions assert created_config.type == SandboxType.LOCAL @@ -4108,8 +4110,11 @@ def test_create_local_sandbox_config_defaults(server: SyncServer, default_user): assert created_config.organization_id == default_user.organization_id -def test_default_e2b_settings_sandbox_config(server: SyncServer, default_user): - created_config = server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=default_user) +@pytest.mark.asyncio +async def test_default_e2b_settings_sandbox_config(server: SyncServer, default_user, event_loop): + created_config = await server.sandbox_config_manager.get_or_create_default_sandbox_config_async( + sandbox_type=SandboxType.E2B, actor=default_user + ) e2b_config = created_config.get_e2b_config() # Assertions @@ -4117,35 +4122,41 @@ def test_default_e2b_settings_sandbox_config(server: SyncServer, default_user): assert e2b_config.template == tool_settings.e2b_sandbox_template_id -def test_update_existing_sandbox_config(server: SyncServer, sandbox_config_fixture, default_user): +@pytest.mark.asyncio +async def test_update_existing_sandbox_config(server: SyncServer, sandbox_config_fixture, default_user, event_loop): update_data = SandboxConfigUpdate(config=E2BSandboxConfig(template="template_2", timeout=120)) - updated_config = server.sandbox_config_manager.update_sandbox_config(sandbox_config_fixture.id, update_data, actor=default_user) + updated_config = await server.sandbox_config_manager.update_sandbox_config_async( + sandbox_config_fixture.id, update_data, actor=default_user + ) # Assertions assert updated_config.config["template"] == "template_2" assert updated_config.config["timeout"] == 120 -def test_delete_sandbox_config(server: SyncServer, sandbox_config_fixture, default_user): - deleted_config = server.sandbox_config_manager.delete_sandbox_config(sandbox_config_fixture.id, actor=default_user) +@pytest.mark.asyncio +async def test_delete_sandbox_config(server: SyncServer, sandbox_config_fixture, default_user, event_loop): + deleted_config = await server.sandbox_config_manager.delete_sandbox_config_async(sandbox_config_fixture.id, actor=default_user) # Assertions to verify deletion assert deleted_config.id == sandbox_config_fixture.id # Verify it no longer exists - config_list = server.sandbox_config_manager.list_sandbox_configs(actor=default_user) + config_list = await server.sandbox_config_manager.list_sandbox_configs_async(actor=default_user) assert sandbox_config_fixture.id not in [config.id for config in config_list] -def test_get_sandbox_config_by_type(server: SyncServer, sandbox_config_fixture, default_user): - retrieved_config = server.sandbox_config_manager.get_sandbox_config_by_type(sandbox_config_fixture.type, actor=default_user) +@pytest.mark.asyncio +async def test_get_sandbox_config_by_type(server: SyncServer, sandbox_config_fixture, default_user, event_loop): + retrieved_config = await server.sandbox_config_manager.get_sandbox_config_by_type_async(sandbox_config_fixture.type, actor=default_user) # Assertions to verify correct retrieval assert retrieved_config.id == sandbox_config_fixture.id assert retrieved_config.type == sandbox_config_fixture.type -def test_list_sandbox_configs(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_list_sandbox_configs(server: SyncServer, default_user, event_loop): # Creating multiple sandbox configs config_e2b_create = SandboxConfigCreate( config=E2BSandboxConfig(), @@ -4153,29 +4164,29 @@ def test_list_sandbox_configs(server: SyncServer, default_user): config_local_create = SandboxConfigCreate( config=LocalSandboxConfig(sandbox_dir=""), ) - config_e2b = server.sandbox_config_manager.create_or_update_sandbox_config(config_e2b_create, actor=default_user) + config_e2b = await server.sandbox_config_manager.create_or_update_sandbox_config_async(config_e2b_create, actor=default_user) if USING_SQLITE: time.sleep(CREATE_DELAY_SQLITE) - config_local = server.sandbox_config_manager.create_or_update_sandbox_config(config_local_create, actor=default_user) + config_local = await server.sandbox_config_manager.create_or_update_sandbox_config_async(config_local_create, actor=default_user) # List configs without pagination - configs = server.sandbox_config_manager.list_sandbox_configs(actor=default_user) + configs = await server.sandbox_config_manager.list_sandbox_configs_async(actor=default_user) assert len(configs) >= 2 # List configs with pagination - paginated_configs = server.sandbox_config_manager.list_sandbox_configs(actor=default_user, limit=1) + paginated_configs = await server.sandbox_config_manager.list_sandbox_configs_async(actor=default_user, limit=1) assert len(paginated_configs) == 1 - next_page = server.sandbox_config_manager.list_sandbox_configs(actor=default_user, after=paginated_configs[-1].id, limit=1) + next_page = await server.sandbox_config_manager.list_sandbox_configs_async(actor=default_user, after=paginated_configs[-1].id, limit=1) assert len(next_page) == 1 assert next_page[0].id != paginated_configs[0].id # List configs using sandbox_type filter - configs = server.sandbox_config_manager.list_sandbox_configs(actor=default_user, sandbox_type=SandboxType.E2B) + configs = await server.sandbox_config_manager.list_sandbox_configs_async(actor=default_user, sandbox_type=SandboxType.E2B) assert len(configs) == 1 assert configs[0].id == config_e2b.id - configs = server.sandbox_config_manager.list_sandbox_configs(actor=default_user, sandbox_type=SandboxType.LOCAL) + configs = await server.sandbox_config_manager.list_sandbox_configs_async(actor=default_user, sandbox_type=SandboxType.LOCAL) assert len(configs) == 1 assert configs[0].id == config_local.id @@ -4185,9 +4196,10 @@ def test_list_sandbox_configs(server: SyncServer, default_user): # ====================================================================================================================== -def test_create_sandbox_env_var(server: SyncServer, sandbox_config_fixture, default_user): +@pytest.mark.asyncio +async def test_create_sandbox_env_var(server: SyncServer, sandbox_config_fixture, default_user, event_loop): env_var_create = SandboxEnvironmentVariableCreate(key="TEST_VAR", value="test_value", description="A test environment variable.") - created_env_var = server.sandbox_config_manager.create_sandbox_env_var( + created_env_var = await server.sandbox_config_manager.create_sandbox_env_var_async( env_var_create, sandbox_config_id=sandbox_config_fixture.id, actor=default_user ) @@ -4197,54 +4209,68 @@ def test_create_sandbox_env_var(server: SyncServer, sandbox_config_fixture, defa assert created_env_var.organization_id == default_user.organization_id -def test_update_sandbox_env_var(server: SyncServer, sandbox_env_var_fixture, default_user): +@pytest.mark.asyncio +async def test_update_sandbox_env_var(server: SyncServer, sandbox_env_var_fixture, default_user, event_loop): update_data = SandboxEnvironmentVariableUpdate(value="updated_value") - updated_env_var = server.sandbox_config_manager.update_sandbox_env_var(sandbox_env_var_fixture.id, update_data, actor=default_user) + updated_env_var = await server.sandbox_config_manager.update_sandbox_env_var_async( + sandbox_env_var_fixture.id, update_data, actor=default_user + ) # Assertions assert updated_env_var.value == "updated_value" assert updated_env_var.id == sandbox_env_var_fixture.id -def test_delete_sandbox_env_var(server: SyncServer, sandbox_config_fixture, sandbox_env_var_fixture, default_user): - deleted_env_var = server.sandbox_config_manager.delete_sandbox_env_var(sandbox_env_var_fixture.id, actor=default_user) +@pytest.mark.asyncio +async def test_delete_sandbox_env_var(server: SyncServer, sandbox_config_fixture, sandbox_env_var_fixture, default_user, event_loop): + deleted_env_var = await server.sandbox_config_manager.delete_sandbox_env_var_async(sandbox_env_var_fixture.id, actor=default_user) # Assertions to verify deletion assert deleted_env_var.id == sandbox_env_var_fixture.id # Verify it no longer exists - env_vars = server.sandbox_config_manager.list_sandbox_env_vars(sandbox_config_id=sandbox_config_fixture.id, actor=default_user) + env_vars = await server.sandbox_config_manager.list_sandbox_env_vars_async( + sandbox_config_id=sandbox_config_fixture.id, actor=default_user + ) assert sandbox_env_var_fixture.id not in [env_var.id for env_var in env_vars] -def test_list_sandbox_env_vars(server: SyncServer, sandbox_config_fixture, default_user): +@pytest.mark.asyncio +async def test_list_sandbox_env_vars(server: SyncServer, sandbox_config_fixture, default_user, event_loop): # Creating multiple environment variables env_var_create_a = SandboxEnvironmentVariableCreate(key="VAR1", value="value1") env_var_create_b = SandboxEnvironmentVariableCreate(key="VAR2", value="value2") - server.sandbox_config_manager.create_sandbox_env_var(env_var_create_a, sandbox_config_id=sandbox_config_fixture.id, actor=default_user) + await server.sandbox_config_manager.create_sandbox_env_var_async( + env_var_create_a, sandbox_config_id=sandbox_config_fixture.id, actor=default_user + ) if USING_SQLITE: time.sleep(CREATE_DELAY_SQLITE) - server.sandbox_config_manager.create_sandbox_env_var(env_var_create_b, sandbox_config_id=sandbox_config_fixture.id, actor=default_user) + await server.sandbox_config_manager.create_sandbox_env_var_async( + env_var_create_b, sandbox_config_id=sandbox_config_fixture.id, actor=default_user + ) # List env vars without pagination - env_vars = server.sandbox_config_manager.list_sandbox_env_vars(sandbox_config_id=sandbox_config_fixture.id, actor=default_user) + env_vars = await server.sandbox_config_manager.list_sandbox_env_vars_async( + sandbox_config_id=sandbox_config_fixture.id, actor=default_user + ) assert len(env_vars) >= 2 # List env vars with pagination - paginated_env_vars = server.sandbox_config_manager.list_sandbox_env_vars( + paginated_env_vars = await server.sandbox_config_manager.list_sandbox_env_vars_async( sandbox_config_id=sandbox_config_fixture.id, actor=default_user, limit=1 ) assert len(paginated_env_vars) == 1 - next_page = server.sandbox_config_manager.list_sandbox_env_vars( + next_page = await server.sandbox_config_manager.list_sandbox_env_vars_async( sandbox_config_id=sandbox_config_fixture.id, actor=default_user, after=paginated_env_vars[-1].id, limit=1 ) assert len(next_page) == 1 assert next_page[0].id != paginated_env_vars[0].id -def test_get_sandbox_env_var_by_key(server: SyncServer, sandbox_env_var_fixture, default_user): - retrieved_env_var = server.sandbox_config_manager.get_sandbox_env_var_by_key_and_sandbox_config_id( +@pytest.mark.asyncio +async def test_get_sandbox_env_var_by_key(server: SyncServer, sandbox_env_var_fixture, default_user, event_loop): + retrieved_env_var = await server.sandbox_config_manager.get_sandbox_env_var_by_key_and_sandbox_config_id_async( sandbox_env_var_fixture.key, sandbox_env_var_fixture.sandbox_config_id, actor=default_user )