feat: add profile auth method for bedrock (#695)

This commit is contained in:
Ari Webb
2026-01-28 11:33:43 -08:00
committed by GitHub
parent b304107a60
commit ebdf78302d
4 changed files with 601 additions and 47 deletions

View File

@@ -1,14 +1,20 @@
import { Box, Text, useInput } from "ink";
import { useCallback, useEffect, useRef, useState } from "react";
import {
type AuthMethod,
BYOK_PROVIDERS,
type ByokProvider,
checkProviderApiKey,
createOrUpdateProvider,
getConnectedProviders,
type ProviderField,
type ProviderResponse,
removeProviderByName,
} from "../../providers/byok-providers";
import {
type AwsProfile,
parseAwsCredentials,
} from "../../utils/aws-credentials";
import { useTerminalWidth } from "../hooks/useTerminalWidth";
import { colors } from "./colors";
@@ -17,7 +23,9 @@ const SOLID_LINE = "─";
type ViewState =
| { type: "list" }
| { type: "input"; provider: ByokProvider }
| { type: "multiInput"; provider: ByokProvider }
| { type: "multiInput"; provider: ByokProvider; authMethod?: AuthMethod }
| { type: "methodSelect"; provider: ByokProvider }
| { type: "profileSelect"; provider: ByokProvider }
| { type: "options"; provider: ByokProvider; providerId: string };
type ValidationState = "idle" | "validating" | "valid" | "invalid";
@@ -50,6 +58,12 @@ export function ProviderSelector({
// Multi-field input state (for providers like Bedrock)
const [fieldValues, setFieldValues] = useState<Record<string, string>>({});
const [focusedFieldIndex, setFocusedFieldIndex] = useState(0);
// Auth method selection state (for providers with multiple auth options)
const [methodIndex, setMethodIndex] = useState(0);
// AWS profile selection state
const [awsProfiles, setAwsProfiles] = useState<AwsProfile[]>([]);
const [profileIndex, setProfileIndex] = useState(0);
const [isLoadingProfiles, setIsLoadingProfiles] = useState(false);
const mountedRef = useRef(true);
useEffect(() => {
@@ -111,8 +125,12 @@ export function ProviderSelector({
setViewState({ type: "options", provider, providerId });
setOptionIndex(0);
}
} else if ("authMethods" in provider && provider.authMethods) {
// Provider with multiple auth methods - show method selection
setViewState({ type: "methodSelect", provider });
setMethodIndex(0);
} else if ("fields" in provider && provider.fields) {
// Multi-field provider (like Bedrock) - show multi-input view
// Multi-field provider - show multi-input view
setViewState({ type: "multiInput", provider });
setFieldValues({});
setFocusedFieldIndex(0);
@@ -129,6 +147,69 @@ export function ProviderSelector({
[isConnected, getProviderId, onStartOAuth],
);
// Handle selecting an auth method
const handleSelectAuthMethod = useCallback(
async (provider: ByokProvider, authMethod: AuthMethod) => {
// Special handling for profile method - load AWS profiles first
if (authMethod.id === "profile") {
setIsLoadingProfiles(true);
setViewState({ type: "profileSelect", provider });
setProfileIndex(0);
// Load profiles asynchronously
parseAwsCredentials()
.then((profiles) => {
if (mountedRef.current) {
setAwsProfiles(profiles);
setIsLoadingProfiles(false);
}
})
.catch((err) => {
// eslint-disable-next-line no-console
console.error("Failed to parse AWS credentials:", err);
if (mountedRef.current) {
setAwsProfiles([]);
setIsLoadingProfiles(false);
}
});
return;
}
setViewState({ type: "multiInput", provider, authMethod });
setFieldValues({});
setFocusedFieldIndex(0);
setValidationState("idle");
setValidationError(null);
},
[],
);
// Handle selecting an AWS profile - pre-fill IAM fields with credentials
const handleSelectAwsProfile = useCallback(
(provider: ByokProvider, profile: AwsProfile) => {
// Find the IAM auth method to use its fields
const iamMethod =
"authMethods" in provider
? provider.authMethods?.find((m) => m.id === "iam")
: undefined;
if (!iamMethod) return;
// Pre-fill field values from the profile
setFieldValues({
accessKey: profile.accessKeyId || "",
apiKey: profile.secretAccessKey || "",
region: profile.region || "",
});
setViewState({ type: "multiInput", provider, authMethod: iamMethod });
setFocusedFieldIndex(profile.region ? 0 : 2); // Focus region if not set
setValidationState("idle");
setValidationError(null);
},
[],
);
// Handle API key validation and saving
const handleValidateAndSave = useCallback(async () => {
if (viewState.type !== "input") return;
@@ -185,10 +266,13 @@ export function ProviderSelector({
// Handle multi-field validation and saving (for providers like Bedrock)
const handleMultiFieldValidateAndSave = useCallback(async () => {
if (viewState.type !== "multiInput") return;
if (!("fields" in viewState.provider) || !viewState.provider.fields) return;
const { provider } = viewState;
const fields = provider.fields;
const { provider, authMethod } = viewState;
// Get fields from authMethod if present, otherwise from provider
const fields: ProviderField[] | undefined =
authMethod?.fields ||
("fields" in provider ? (provider.fields as ProviderField[]) : undefined);
if (!fields) return;
// Check all required fields are filled
const allFilled = fields.every((field) => fieldValues[field.key]?.trim());
@@ -197,6 +281,7 @@ export function ProviderSelector({
const apiKey = fieldValues.apiKey?.trim() || "";
const accessKey = fieldValues.accessKey?.trim();
const region = fieldValues.region?.trim();
const profile = fieldValues.profile?.trim();
// If already validated, save
if (validationState === "valid") {
@@ -207,6 +292,7 @@ export function ProviderSelector({
apiKey,
accessKey,
region,
profile,
);
// Refresh connected providers
const providers = await getConnectedProviders();
@@ -237,6 +323,7 @@ export function ProviderSelector({
apiKey,
accessKey,
region,
profile,
);
if (mountedRef.current) {
setValidationState("valid");
@@ -269,16 +356,6 @@ export function ProviderSelector({
}
}, [viewState]);
// Handle update key option
const handleUpdateKey = useCallback(() => {
if (viewState.type !== "options") return;
const { provider } = viewState;
setViewState({ type: "input", provider });
setApiKeyInput("");
setValidationState("idle");
setValidationError(null);
}, [viewState]);
useInput((input, key) => {
// CTRL-C: immediately cancel
if (key.ctrl && input === "c") {
@@ -328,16 +405,71 @@ export function ProviderSelector({
setValidationError(null);
}
}
} else if (viewState.type === "multiInput") {
if (!("fields" in viewState.provider) || !viewState.provider.fields)
} else if (viewState.type === "methodSelect") {
// Handle auth method selection
if (
!("authMethods" in viewState.provider) ||
!viewState.provider.authMethods
)
return;
const fields = viewState.provider.fields;
const currentField = fields[focusedFieldIndex];
if (!currentField) return;
const authMethods = viewState.provider.authMethods;
if (key.escape) {
// Back to list
setViewState({ type: "list" });
setMethodIndex(0);
} else if (key.upArrow) {
setMethodIndex((prev) => Math.max(0, prev - 1));
} else if (key.downArrow) {
setMethodIndex((prev) => Math.min(authMethods.length - 1, prev + 1));
} else if (key.return) {
const selectedMethod = authMethods[methodIndex];
if (selectedMethod) {
handleSelectAuthMethod(viewState.provider, selectedMethod);
}
}
} else if (viewState.type === "profileSelect") {
// Handle AWS profile selection
if (isLoadingProfiles) return;
if (key.escape) {
// Back to method select
setViewState({ type: "methodSelect", provider: viewState.provider });
setMethodIndex(0);
setAwsProfiles([]);
setProfileIndex(0);
} else if (key.upArrow) {
setProfileIndex((prev) => Math.max(0, prev - 1));
} else if (key.downArrow) {
setProfileIndex((prev) => Math.min(awsProfiles.length - 1, prev + 1));
} else if (key.return) {
const selectedProfile = awsProfiles[profileIndex];
if (selectedProfile) {
handleSelectAwsProfile(viewState.provider, selectedProfile);
}
}
} else if (viewState.type === "multiInput") {
// Get fields from authMethod if present, otherwise from provider
const fields: ProviderField[] | undefined =
viewState.authMethod?.fields ||
("fields" in viewState.provider
? (viewState.provider.fields as ProviderField[])
: undefined);
if (!fields) return;
const currentField = fields[focusedFieldIndex];
if (!currentField) return;
if (key.escape) {
// Back to method select if provider has authMethods, otherwise back to list
if (
"authMethods" in viewState.provider &&
viewState.provider.authMethods
) {
setViewState({ type: "methodSelect", provider: viewState.provider });
setMethodIndex(0);
} else {
setViewState({ type: "list" });
}
setFieldValues({});
setFocusedFieldIndex(0);
setValidationState("idle");
@@ -377,7 +509,7 @@ export function ProviderSelector({
}
}
} else if (viewState.type === "options") {
const options = ["Update API key", "Disconnect", "Back"];
const options = ["Disconnect", "Back"];
if (key.escape) {
setViewState({ type: "list" });
} else if (key.upArrow) {
@@ -386,8 +518,6 @@ export function ProviderSelector({
setOptionIndex((prev) => Math.min(options.length - 1, prev + 1));
} else if (key.return) {
if (optionIndex === 0) {
handleUpdateKey();
} else if (optionIndex === 1) {
handleDisconnect();
} else {
setViewState({ type: "list" });
@@ -517,17 +647,165 @@ export function ProviderSelector({
);
};
// Render multi-input view (for providers like Bedrock)
const renderMultiInputView = () => {
if (viewState.type !== "multiInput") return null;
if (!("fields" in viewState.provider) || !viewState.provider.fields)
// Render method select view (for providers with multiple auth options)
const renderMethodSelectView = () => {
if (viewState.type !== "methodSelect") return null;
if (
!("authMethods" in viewState.provider) ||
!viewState.provider.authMethods
)
return null;
const { provider } = viewState;
const fields = provider.fields;
const authMethods = viewState.provider.authMethods;
return (
<>
<Box flexDirection="column" marginBottom={1}>
<Text bold color={colors.selector.title}>
Connect {provider.displayName}
</Text>
<Text dimColor>Select authentication method</Text>
</Box>
<Box flexDirection="column">
{authMethods.map((method, index) => {
const isSelected = index === methodIndex;
return (
<Box key={method.id} flexDirection="row">
<Text
color={
isSelected ? colors.selector.itemHighlighted : undefined
}
>
{isSelected ? "> " : " "}
</Text>
<Text
bold={isSelected}
color={
isSelected ? colors.selector.itemHighlighted : undefined
}
>
{method.label}
</Text>
<Text dimColor> · {method.description}</Text>
</Box>
);
})}
</Box>
<Box marginTop={1}>
<Text dimColor>{" "}Enter select · navigate · Esc back</Text>
</Box>
</>
);
};
// Render AWS profile select view
const renderProfileSelectView = () => {
if (viewState.type !== "profileSelect") return null;
const { provider } = viewState;
if (isLoadingProfiles) {
return (
<Box flexDirection="column" marginBottom={1}>
<Text bold color={colors.selector.title}>
Connect {provider.displayName}
</Text>
<Text dimColor>Loading AWS profiles...</Text>
</Box>
);
}
if (awsProfiles.length === 0) {
return (
<>
<Box flexDirection="column" marginBottom={1}>
<Text bold color={colors.selector.title}>
Connect {provider.displayName}
</Text>
<Text color="yellow">No AWS profiles found</Text>
<Text dimColor>
Check that ~/.aws/credentials exists and contains valid profiles.
</Text>
</Box>
<Box marginTop={1}>
<Text dimColor>{" "}Esc back</Text>
</Box>
</>
);
}
return (
<>
<Box flexDirection="column" marginBottom={1}>
<Text bold color={colors.selector.title}>
Connect {provider.displayName}
</Text>
<Text dimColor>Select AWS profile from ~/.aws/credentials</Text>
</Box>
<Box flexDirection="column">
{awsProfiles.map((profile, index) => {
const isSelected = index === profileIndex;
const hasCredentials =
profile.accessKeyId && profile.secretAccessKey;
return (
<Box key={profile.name} flexDirection="row">
<Text
color={
isSelected ? colors.selector.itemHighlighted : undefined
}
>
{isSelected ? "> " : " "}
</Text>
<Text
bold={isSelected}
color={
isSelected ? colors.selector.itemHighlighted : undefined
}
>
{profile.name}
</Text>
<Text dimColor>
{" · "}
{hasCredentials ? (
<>
{profile.accessKeyId?.slice(0, 8)}...
{profile.region && ` · ${profile.region}`}
</>
) : (
<Text color="yellow">missing credentials</Text>
)}
</Text>
</Box>
);
})}
</Box>
<Box marginTop={1}>
<Text dimColor>{" "}Enter select · navigate · Esc back</Text>
</Box>
</>
);
};
// Render multi-input view (for providers like Bedrock)
const renderMultiInputView = () => {
if (viewState.type !== "multiInput") return null;
const { provider, authMethod } = viewState;
// Get fields from authMethod if present, otherwise from provider
const fields: ProviderField[] | undefined =
authMethod?.fields ||
("fields" in provider ? (provider.fields as ProviderField[]) : undefined);
if (!fields) return null;
// Check if all fields are filled
const allFilled = fields.every((field) => fieldValues[field.key]?.trim());
const allFilled = fields.every((field: ProviderField) =>
fieldValues[field.key]?.trim(),
);
const statusText =
validationState === "validating"
@@ -545,23 +823,30 @@ export function ProviderSelector({
? "red"
: undefined;
const hasAuthMethods = "authMethods" in provider && provider.authMethods;
const escText = hasAuthMethods ? "Esc back" : "Esc cancel";
const footerText =
validationState === "valid"
? "Enter to save · Esc cancel"
? `Enter to save · ${escText}`
: allFilled
? "Enter to validate · Tab/↑↓ navigate · Esc cancel"
: "Tab/↑↓ navigate · Esc cancel";
? `Enter to validate · Tab/↑↓ navigate · ${escText}`
: `Tab/↑↓ navigate · ${escText}`;
// Build title - include auth method name if present
const title = authMethod
? `${provider.displayName} · ${authMethod.label}`
: provider.displayName;
return (
<>
<Box flexDirection="column" marginBottom={1}>
<Text bold color={colors.selector.title}>
Connect {provider.displayName}
Connect {title}
</Text>
</Box>
<Box flexDirection="column">
{fields.map((field, index) => {
{fields.map((field: ProviderField, index: number) => {
const isFocused = index === focusedFieldIndex;
const value = fieldValues[field.key] || "";
const displayValue = field.secret ? maskApiKey(value) : value;
@@ -620,7 +905,7 @@ export function ProviderSelector({
const renderOptionsView = () => {
if (viewState.type !== "options") return null;
const { provider } = viewState;
const options = ["Update API key", "Disconnect", "Back"];
const options = ["Disconnect", "Back"];
return (
<>
@@ -677,6 +962,8 @@ export function ProviderSelector({
{viewState.type === "list" && renderListView()}
{viewState.type === "input" && renderInputView()}
{viewState.type === "methodSelect" && renderMethodSelectView()}
{viewState.type === "profileSelect" && renderProfileSelectView()}
{viewState.type === "multiInput" && renderMultiInputView()}
{viewState.type === "options" && renderOptionsView()}
</Box>