feat: add profile auth method for bedrock (#695)
This commit is contained in:
@@ -8,6 +8,10 @@ import {
|
||||
startLocalOAuthServer,
|
||||
startOpenAIOAuth,
|
||||
} from "../../auth/openai-oauth";
|
||||
import {
|
||||
getProviderByName,
|
||||
removeProviderByName,
|
||||
} from "../../providers/byok-providers";
|
||||
import {
|
||||
createOrUpdateMinimaxProvider,
|
||||
getMinimaxProvider,
|
||||
@@ -116,18 +120,23 @@ export async function handleConnect(
|
||||
ctx.buffersRef,
|
||||
ctx.refreshDerived,
|
||||
msg,
|
||||
"Usage: /connect <provider> [options]\n\nAvailable providers:\n \u2022 codex - Connect via OAuth to authenticate with ChatGPT Plus/Pro\n \u2022 zai <api_key> - Connect to zAI with your API key\n \u2022 minimax <api_key> - Connect to MiniMax with your API key",
|
||||
"Usage: /connect <provider> [options]\n\nAvailable providers:\n \u2022 codex - Connect via OAuth to authenticate with ChatGPT Plus/Pro\n \u2022 zai <api_key> - Connect to zAI with your API key\n \u2022 minimax <api_key> - Connect to MiniMax with your API key\n \u2022 bedrock <method> - Connect to AWS Bedrock (iam/profile/default)",
|
||||
false,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (provider !== "codex" && provider !== "zai" && provider !== "minimax") {
|
||||
if (
|
||||
provider !== "codex" &&
|
||||
provider !== "zai" &&
|
||||
provider !== "minimax" &&
|
||||
provider !== "bedrock"
|
||||
) {
|
||||
addCommandResult(
|
||||
ctx.buffersRef,
|
||||
ctx.refreshDerived,
|
||||
msg,
|
||||
`Error: Unknown provider "${provider}"\n\nAvailable providers: codex, zai, minimax\nUsage: /connect <provider> [options]`,
|
||||
`Error: Unknown provider "${provider}"\n\nAvailable providers: codex, zai, minimax, bedrock\nUsage: /connect <provider> [options]`,
|
||||
false,
|
||||
);
|
||||
return;
|
||||
@@ -145,6 +154,12 @@ export async function handleConnect(
|
||||
return;
|
||||
}
|
||||
|
||||
// Bedrock is handled here
|
||||
if (provider === "bedrock") {
|
||||
await handleConnectBedrock(ctx, msg);
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle /connect codex
|
||||
await handleConnectCodex(ctx, msg);
|
||||
}
|
||||
@@ -445,6 +460,69 @@ async function handleDisconnectMinimax(
|
||||
}
|
||||
}
|
||||
|
||||
const BEDROCK_PROVIDER_NAME = "lc-bedrock";
|
||||
|
||||
/**
|
||||
* Handle /disconnect bedrock
|
||||
*/
|
||||
async function handleDisconnectBedrock(
|
||||
ctx: ConnectCommandContext,
|
||||
msg: string,
|
||||
): Promise<void> {
|
||||
// Check if Bedrock provider exists
|
||||
const existing = await getProviderByName(BEDROCK_PROVIDER_NAME);
|
||||
if (!existing) {
|
||||
addCommandResult(
|
||||
ctx.buffersRef,
|
||||
ctx.refreshDerived,
|
||||
msg,
|
||||
'Not currently connected to AWS Bedrock.\n\nUse /connect and select "AWS Bedrock" to connect.',
|
||||
false,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
// Show running status
|
||||
const cmdId = addCommandResult(
|
||||
ctx.buffersRef,
|
||||
ctx.refreshDerived,
|
||||
msg,
|
||||
"Disconnecting from AWS Bedrock...",
|
||||
true,
|
||||
"running",
|
||||
);
|
||||
|
||||
ctx.setCommandRunning(true);
|
||||
|
||||
try {
|
||||
// Remove provider from Letta
|
||||
await removeProviderByName(BEDROCK_PROVIDER_NAME);
|
||||
|
||||
updateCommandResult(
|
||||
ctx.buffersRef,
|
||||
ctx.refreshDerived,
|
||||
cmdId,
|
||||
msg,
|
||||
`\u2713 Disconnected from AWS Bedrock.\n\n` +
|
||||
`Provider '${BEDROCK_PROVIDER_NAME}' removed from Letta.`,
|
||||
true,
|
||||
"finished",
|
||||
);
|
||||
} catch (error) {
|
||||
updateCommandResult(
|
||||
ctx.buffersRef,
|
||||
ctx.refreshDerived,
|
||||
cmdId,
|
||||
msg,
|
||||
`\u2717 Failed to disconnect from Bedrock: ${getErrorMessage(error)}`,
|
||||
false,
|
||||
"finished",
|
||||
);
|
||||
} finally {
|
||||
ctx.setCommandRunning(false);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle /connect minimax command
|
||||
* Usage: /connect minimax <api_key>
|
||||
@@ -515,6 +593,27 @@ export async function handleConnectMinimax(
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle /connect bedrock command
|
||||
* Redirects users to use the interactive /connect UI
|
||||
*/
|
||||
export async function handleConnectBedrock(
|
||||
ctx: ConnectCommandContext,
|
||||
msg: string,
|
||||
): Promise<void> {
|
||||
addCommandResult(
|
||||
ctx.buffersRef,
|
||||
ctx.refreshDerived,
|
||||
msg,
|
||||
'To connect AWS Bedrock, use /connect and select "AWS Bedrock" from the list.\n\n' +
|
||||
"The interactive UI will guide you through:\n" +
|
||||
" • Choosing an authentication method (IAM, Profile, or Default)\n" +
|
||||
" • Entering your credentials\n" +
|
||||
" • Validating the connection",
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle /disconnect command
|
||||
* Usage: /disconnect <provider>
|
||||
@@ -532,7 +631,7 @@ export async function handleDisconnect(
|
||||
ctx.buffersRef,
|
||||
ctx.refreshDerived,
|
||||
msg,
|
||||
"Usage: /disconnect <provider>\n\nAvailable providers: codex, claude, zai, minimax",
|
||||
"Usage: /disconnect <provider>\n\nAvailable providers: codex, claude, zai, minimax, bedrock",
|
||||
false,
|
||||
);
|
||||
return;
|
||||
@@ -550,6 +649,12 @@ export async function handleDisconnect(
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle /disconnect bedrock
|
||||
if (provider === "bedrock") {
|
||||
await handleDisconnectBedrock(ctx, msg);
|
||||
return;
|
||||
}
|
||||
|
||||
// Handle /disconnect codex
|
||||
if (provider === "codex") {
|
||||
await handleDisconnectCodex(ctx, msg);
|
||||
@@ -567,7 +672,7 @@ export async function handleDisconnect(
|
||||
ctx.buffersRef,
|
||||
ctx.refreshDerived,
|
||||
msg,
|
||||
`Error: Unknown provider "${provider}"\n\nAvailable providers: codex, claude, zai, minimax\nUsage: /disconnect <provider>`,
|
||||
`Error: Unknown provider "${provider}"\n\nAvailable providers: codex, claude, zai, minimax, bedrock\nUsage: /disconnect <provider>`,
|
||||
false,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
import { Box, Text, useInput } from "ink";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import {
|
||||
type AuthMethod,
|
||||
BYOK_PROVIDERS,
|
||||
type ByokProvider,
|
||||
checkProviderApiKey,
|
||||
createOrUpdateProvider,
|
||||
getConnectedProviders,
|
||||
type ProviderField,
|
||||
type ProviderResponse,
|
||||
removeProviderByName,
|
||||
} from "../../providers/byok-providers";
|
||||
import {
|
||||
type AwsProfile,
|
||||
parseAwsCredentials,
|
||||
} from "../../utils/aws-credentials";
|
||||
import { useTerminalWidth } from "../hooks/useTerminalWidth";
|
||||
import { colors } from "./colors";
|
||||
|
||||
@@ -17,7 +23,9 @@ const SOLID_LINE = "─";
|
||||
type ViewState =
|
||||
| { type: "list" }
|
||||
| { type: "input"; provider: ByokProvider }
|
||||
| { type: "multiInput"; provider: ByokProvider }
|
||||
| { type: "multiInput"; provider: ByokProvider; authMethod?: AuthMethod }
|
||||
| { type: "methodSelect"; provider: ByokProvider }
|
||||
| { type: "profileSelect"; provider: ByokProvider }
|
||||
| { type: "options"; provider: ByokProvider; providerId: string };
|
||||
|
||||
type ValidationState = "idle" | "validating" | "valid" | "invalid";
|
||||
@@ -50,6 +58,12 @@ export function ProviderSelector({
|
||||
// Multi-field input state (for providers like Bedrock)
|
||||
const [fieldValues, setFieldValues] = useState<Record<string, string>>({});
|
||||
const [focusedFieldIndex, setFocusedFieldIndex] = useState(0);
|
||||
// Auth method selection state (for providers with multiple auth options)
|
||||
const [methodIndex, setMethodIndex] = useState(0);
|
||||
// AWS profile selection state
|
||||
const [awsProfiles, setAwsProfiles] = useState<AwsProfile[]>([]);
|
||||
const [profileIndex, setProfileIndex] = useState(0);
|
||||
const [isLoadingProfiles, setIsLoadingProfiles] = useState(false);
|
||||
|
||||
const mountedRef = useRef(true);
|
||||
useEffect(() => {
|
||||
@@ -111,8 +125,12 @@ export function ProviderSelector({
|
||||
setViewState({ type: "options", provider, providerId });
|
||||
setOptionIndex(0);
|
||||
}
|
||||
} else if ("authMethods" in provider && provider.authMethods) {
|
||||
// Provider with multiple auth methods - show method selection
|
||||
setViewState({ type: "methodSelect", provider });
|
||||
setMethodIndex(0);
|
||||
} else if ("fields" in provider && provider.fields) {
|
||||
// Multi-field provider (like Bedrock) - show multi-input view
|
||||
// Multi-field provider - show multi-input view
|
||||
setViewState({ type: "multiInput", provider });
|
||||
setFieldValues({});
|
||||
setFocusedFieldIndex(0);
|
||||
@@ -129,6 +147,69 @@ export function ProviderSelector({
|
||||
[isConnected, getProviderId, onStartOAuth],
|
||||
);
|
||||
|
||||
// Handle selecting an auth method
|
||||
const handleSelectAuthMethod = useCallback(
|
||||
async (provider: ByokProvider, authMethod: AuthMethod) => {
|
||||
// Special handling for profile method - load AWS profiles first
|
||||
if (authMethod.id === "profile") {
|
||||
setIsLoadingProfiles(true);
|
||||
setViewState({ type: "profileSelect", provider });
|
||||
setProfileIndex(0);
|
||||
|
||||
// Load profiles asynchronously
|
||||
parseAwsCredentials()
|
||||
.then((profiles) => {
|
||||
if (mountedRef.current) {
|
||||
setAwsProfiles(profiles);
|
||||
setIsLoadingProfiles(false);
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
// eslint-disable-next-line no-console
|
||||
console.error("Failed to parse AWS credentials:", err);
|
||||
if (mountedRef.current) {
|
||||
setAwsProfiles([]);
|
||||
setIsLoadingProfiles(false);
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
setViewState({ type: "multiInput", provider, authMethod });
|
||||
setFieldValues({});
|
||||
setFocusedFieldIndex(0);
|
||||
setValidationState("idle");
|
||||
setValidationError(null);
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
// Handle selecting an AWS profile - pre-fill IAM fields with credentials
|
||||
const handleSelectAwsProfile = useCallback(
|
||||
(provider: ByokProvider, profile: AwsProfile) => {
|
||||
// Find the IAM auth method to use its fields
|
||||
const iamMethod =
|
||||
"authMethods" in provider
|
||||
? provider.authMethods?.find((m) => m.id === "iam")
|
||||
: undefined;
|
||||
|
||||
if (!iamMethod) return;
|
||||
|
||||
// Pre-fill field values from the profile
|
||||
setFieldValues({
|
||||
accessKey: profile.accessKeyId || "",
|
||||
apiKey: profile.secretAccessKey || "",
|
||||
region: profile.region || "",
|
||||
});
|
||||
|
||||
setViewState({ type: "multiInput", provider, authMethod: iamMethod });
|
||||
setFocusedFieldIndex(profile.region ? 0 : 2); // Focus region if not set
|
||||
setValidationState("idle");
|
||||
setValidationError(null);
|
||||
},
|
||||
[],
|
||||
);
|
||||
|
||||
// Handle API key validation and saving
|
||||
const handleValidateAndSave = useCallback(async () => {
|
||||
if (viewState.type !== "input") return;
|
||||
@@ -185,10 +266,13 @@ export function ProviderSelector({
|
||||
// Handle multi-field validation and saving (for providers like Bedrock)
|
||||
const handleMultiFieldValidateAndSave = useCallback(async () => {
|
||||
if (viewState.type !== "multiInput") return;
|
||||
if (!("fields" in viewState.provider) || !viewState.provider.fields) return;
|
||||
|
||||
const { provider } = viewState;
|
||||
const fields = provider.fields;
|
||||
const { provider, authMethod } = viewState;
|
||||
// Get fields from authMethod if present, otherwise from provider
|
||||
const fields: ProviderField[] | undefined =
|
||||
authMethod?.fields ||
|
||||
("fields" in provider ? (provider.fields as ProviderField[]) : undefined);
|
||||
if (!fields) return;
|
||||
|
||||
// Check all required fields are filled
|
||||
const allFilled = fields.every((field) => fieldValues[field.key]?.trim());
|
||||
@@ -197,6 +281,7 @@ export function ProviderSelector({
|
||||
const apiKey = fieldValues.apiKey?.trim() || "";
|
||||
const accessKey = fieldValues.accessKey?.trim();
|
||||
const region = fieldValues.region?.trim();
|
||||
const profile = fieldValues.profile?.trim();
|
||||
|
||||
// If already validated, save
|
||||
if (validationState === "valid") {
|
||||
@@ -207,6 +292,7 @@ export function ProviderSelector({
|
||||
apiKey,
|
||||
accessKey,
|
||||
region,
|
||||
profile,
|
||||
);
|
||||
// Refresh connected providers
|
||||
const providers = await getConnectedProviders();
|
||||
@@ -237,6 +323,7 @@ export function ProviderSelector({
|
||||
apiKey,
|
||||
accessKey,
|
||||
region,
|
||||
profile,
|
||||
);
|
||||
if (mountedRef.current) {
|
||||
setValidationState("valid");
|
||||
@@ -269,16 +356,6 @@ export function ProviderSelector({
|
||||
}
|
||||
}, [viewState]);
|
||||
|
||||
// Handle update key option
|
||||
const handleUpdateKey = useCallback(() => {
|
||||
if (viewState.type !== "options") return;
|
||||
const { provider } = viewState;
|
||||
setViewState({ type: "input", provider });
|
||||
setApiKeyInput("");
|
||||
setValidationState("idle");
|
||||
setValidationError(null);
|
||||
}, [viewState]);
|
||||
|
||||
useInput((input, key) => {
|
||||
// CTRL-C: immediately cancel
|
||||
if (key.ctrl && input === "c") {
|
||||
@@ -328,16 +405,71 @@ export function ProviderSelector({
|
||||
setValidationError(null);
|
||||
}
|
||||
}
|
||||
} else if (viewState.type === "multiInput") {
|
||||
if (!("fields" in viewState.provider) || !viewState.provider.fields)
|
||||
} else if (viewState.type === "methodSelect") {
|
||||
// Handle auth method selection
|
||||
if (
|
||||
!("authMethods" in viewState.provider) ||
|
||||
!viewState.provider.authMethods
|
||||
)
|
||||
return;
|
||||
const fields = viewState.provider.fields;
|
||||
const currentField = fields[focusedFieldIndex];
|
||||
if (!currentField) return;
|
||||
const authMethods = viewState.provider.authMethods;
|
||||
|
||||
if (key.escape) {
|
||||
// Back to list
|
||||
setViewState({ type: "list" });
|
||||
setMethodIndex(0);
|
||||
} else if (key.upArrow) {
|
||||
setMethodIndex((prev) => Math.max(0, prev - 1));
|
||||
} else if (key.downArrow) {
|
||||
setMethodIndex((prev) => Math.min(authMethods.length - 1, prev + 1));
|
||||
} else if (key.return) {
|
||||
const selectedMethod = authMethods[methodIndex];
|
||||
if (selectedMethod) {
|
||||
handleSelectAuthMethod(viewState.provider, selectedMethod);
|
||||
}
|
||||
}
|
||||
} else if (viewState.type === "profileSelect") {
|
||||
// Handle AWS profile selection
|
||||
if (isLoadingProfiles) return;
|
||||
|
||||
if (key.escape) {
|
||||
// Back to method select
|
||||
setViewState({ type: "methodSelect", provider: viewState.provider });
|
||||
setMethodIndex(0);
|
||||
setAwsProfiles([]);
|
||||
setProfileIndex(0);
|
||||
} else if (key.upArrow) {
|
||||
setProfileIndex((prev) => Math.max(0, prev - 1));
|
||||
} else if (key.downArrow) {
|
||||
setProfileIndex((prev) => Math.min(awsProfiles.length - 1, prev + 1));
|
||||
} else if (key.return) {
|
||||
const selectedProfile = awsProfiles[profileIndex];
|
||||
if (selectedProfile) {
|
||||
handleSelectAwsProfile(viewState.provider, selectedProfile);
|
||||
}
|
||||
}
|
||||
} else if (viewState.type === "multiInput") {
|
||||
// Get fields from authMethod if present, otherwise from provider
|
||||
const fields: ProviderField[] | undefined =
|
||||
viewState.authMethod?.fields ||
|
||||
("fields" in viewState.provider
|
||||
? (viewState.provider.fields as ProviderField[])
|
||||
: undefined);
|
||||
if (!fields) return;
|
||||
const currentField = fields[focusedFieldIndex];
|
||||
if (!currentField) return;
|
||||
|
||||
if (key.escape) {
|
||||
// Back to method select if provider has authMethods, otherwise back to list
|
||||
if (
|
||||
"authMethods" in viewState.provider &&
|
||||
viewState.provider.authMethods
|
||||
) {
|
||||
setViewState({ type: "methodSelect", provider: viewState.provider });
|
||||
setMethodIndex(0);
|
||||
} else {
|
||||
setViewState({ type: "list" });
|
||||
}
|
||||
setFieldValues({});
|
||||
setFocusedFieldIndex(0);
|
||||
setValidationState("idle");
|
||||
@@ -377,7 +509,7 @@ export function ProviderSelector({
|
||||
}
|
||||
}
|
||||
} else if (viewState.type === "options") {
|
||||
const options = ["Update API key", "Disconnect", "Back"];
|
||||
const options = ["Disconnect", "Back"];
|
||||
if (key.escape) {
|
||||
setViewState({ type: "list" });
|
||||
} else if (key.upArrow) {
|
||||
@@ -386,8 +518,6 @@ export function ProviderSelector({
|
||||
setOptionIndex((prev) => Math.min(options.length - 1, prev + 1));
|
||||
} else if (key.return) {
|
||||
if (optionIndex === 0) {
|
||||
handleUpdateKey();
|
||||
} else if (optionIndex === 1) {
|
||||
handleDisconnect();
|
||||
} else {
|
||||
setViewState({ type: "list" });
|
||||
@@ -517,17 +647,165 @@ export function ProviderSelector({
|
||||
);
|
||||
};
|
||||
|
||||
// Render multi-input view (for providers like Bedrock)
|
||||
const renderMultiInputView = () => {
|
||||
if (viewState.type !== "multiInput") return null;
|
||||
if (!("fields" in viewState.provider) || !viewState.provider.fields)
|
||||
// Render method select view (for providers with multiple auth options)
|
||||
const renderMethodSelectView = () => {
|
||||
if (viewState.type !== "methodSelect") return null;
|
||||
if (
|
||||
!("authMethods" in viewState.provider) ||
|
||||
!viewState.provider.authMethods
|
||||
)
|
||||
return null;
|
||||
|
||||
const { provider } = viewState;
|
||||
const fields = provider.fields;
|
||||
const authMethods = viewState.provider.authMethods;
|
||||
|
||||
return (
|
||||
<>
|
||||
<Box flexDirection="column" marginBottom={1}>
|
||||
<Text bold color={colors.selector.title}>
|
||||
Connect {provider.displayName}
|
||||
</Text>
|
||||
<Text dimColor>Select authentication method</Text>
|
||||
</Box>
|
||||
|
||||
<Box flexDirection="column">
|
||||
{authMethods.map((method, index) => {
|
||||
const isSelected = index === methodIndex;
|
||||
return (
|
||||
<Box key={method.id} flexDirection="row">
|
||||
<Text
|
||||
color={
|
||||
isSelected ? colors.selector.itemHighlighted : undefined
|
||||
}
|
||||
>
|
||||
{isSelected ? "> " : " "}
|
||||
</Text>
|
||||
<Text
|
||||
bold={isSelected}
|
||||
color={
|
||||
isSelected ? colors.selector.itemHighlighted : undefined
|
||||
}
|
||||
>
|
||||
{method.label}
|
||||
</Text>
|
||||
<Text dimColor> · {method.description}</Text>
|
||||
</Box>
|
||||
);
|
||||
})}
|
||||
</Box>
|
||||
|
||||
<Box marginTop={1}>
|
||||
<Text dimColor>{" "}Enter select · ↑↓ navigate · Esc back</Text>
|
||||
</Box>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
// Render AWS profile select view
|
||||
const renderProfileSelectView = () => {
|
||||
if (viewState.type !== "profileSelect") return null;
|
||||
|
||||
const { provider } = viewState;
|
||||
|
||||
if (isLoadingProfiles) {
|
||||
return (
|
||||
<Box flexDirection="column" marginBottom={1}>
|
||||
<Text bold color={colors.selector.title}>
|
||||
Connect {provider.displayName}
|
||||
</Text>
|
||||
<Text dimColor>Loading AWS profiles...</Text>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
if (awsProfiles.length === 0) {
|
||||
return (
|
||||
<>
|
||||
<Box flexDirection="column" marginBottom={1}>
|
||||
<Text bold color={colors.selector.title}>
|
||||
Connect {provider.displayName}
|
||||
</Text>
|
||||
<Text color="yellow">No AWS profiles found</Text>
|
||||
<Text dimColor>
|
||||
Check that ~/.aws/credentials exists and contains valid profiles.
|
||||
</Text>
|
||||
</Box>
|
||||
<Box marginTop={1}>
|
||||
<Text dimColor>{" "}Esc back</Text>
|
||||
</Box>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Box flexDirection="column" marginBottom={1}>
|
||||
<Text bold color={colors.selector.title}>
|
||||
Connect {provider.displayName}
|
||||
</Text>
|
||||
<Text dimColor>Select AWS profile from ~/.aws/credentials</Text>
|
||||
</Box>
|
||||
|
||||
<Box flexDirection="column">
|
||||
{awsProfiles.map((profile, index) => {
|
||||
const isSelected = index === profileIndex;
|
||||
const hasCredentials =
|
||||
profile.accessKeyId && profile.secretAccessKey;
|
||||
return (
|
||||
<Box key={profile.name} flexDirection="row">
|
||||
<Text
|
||||
color={
|
||||
isSelected ? colors.selector.itemHighlighted : undefined
|
||||
}
|
||||
>
|
||||
{isSelected ? "> " : " "}
|
||||
</Text>
|
||||
<Text
|
||||
bold={isSelected}
|
||||
color={
|
||||
isSelected ? colors.selector.itemHighlighted : undefined
|
||||
}
|
||||
>
|
||||
{profile.name}
|
||||
</Text>
|
||||
<Text dimColor>
|
||||
{" · "}
|
||||
{hasCredentials ? (
|
||||
<>
|
||||
{profile.accessKeyId?.slice(0, 8)}...
|
||||
{profile.region && ` · ${profile.region}`}
|
||||
</>
|
||||
) : (
|
||||
<Text color="yellow">missing credentials</Text>
|
||||
)}
|
||||
</Text>
|
||||
</Box>
|
||||
);
|
||||
})}
|
||||
</Box>
|
||||
|
||||
<Box marginTop={1}>
|
||||
<Text dimColor>{" "}Enter select · ↑↓ navigate · Esc back</Text>
|
||||
</Box>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
// Render multi-input view (for providers like Bedrock)
|
||||
const renderMultiInputView = () => {
|
||||
if (viewState.type !== "multiInput") return null;
|
||||
|
||||
const { provider, authMethod } = viewState;
|
||||
// Get fields from authMethod if present, otherwise from provider
|
||||
const fields: ProviderField[] | undefined =
|
||||
authMethod?.fields ||
|
||||
("fields" in provider ? (provider.fields as ProviderField[]) : undefined);
|
||||
if (!fields) return null;
|
||||
|
||||
// Check if all fields are filled
|
||||
const allFilled = fields.every((field) => fieldValues[field.key]?.trim());
|
||||
const allFilled = fields.every((field: ProviderField) =>
|
||||
fieldValues[field.key]?.trim(),
|
||||
);
|
||||
|
||||
const statusText =
|
||||
validationState === "validating"
|
||||
@@ -545,23 +823,30 @@ export function ProviderSelector({
|
||||
? "red"
|
||||
: undefined;
|
||||
|
||||
const hasAuthMethods = "authMethods" in provider && provider.authMethods;
|
||||
const escText = hasAuthMethods ? "Esc back" : "Esc cancel";
|
||||
const footerText =
|
||||
validationState === "valid"
|
||||
? "Enter to save · Esc cancel"
|
||||
? `Enter to save · ${escText}`
|
||||
: allFilled
|
||||
? "Enter to validate · Tab/↑↓ navigate · Esc cancel"
|
||||
: "Tab/↑↓ navigate · Esc cancel";
|
||||
? `Enter to validate · Tab/↑↓ navigate · ${escText}`
|
||||
: `Tab/↑↓ navigate · ${escText}`;
|
||||
|
||||
// Build title - include auth method name if present
|
||||
const title = authMethod
|
||||
? `${provider.displayName} · ${authMethod.label}`
|
||||
: provider.displayName;
|
||||
|
||||
return (
|
||||
<>
|
||||
<Box flexDirection="column" marginBottom={1}>
|
||||
<Text bold color={colors.selector.title}>
|
||||
Connect {provider.displayName}
|
||||
Connect {title}
|
||||
</Text>
|
||||
</Box>
|
||||
|
||||
<Box flexDirection="column">
|
||||
{fields.map((field, index) => {
|
||||
{fields.map((field: ProviderField, index: number) => {
|
||||
const isFocused = index === focusedFieldIndex;
|
||||
const value = fieldValues[field.key] || "";
|
||||
const displayValue = field.secret ? maskApiKey(value) : value;
|
||||
@@ -620,7 +905,7 @@ export function ProviderSelector({
|
||||
const renderOptionsView = () => {
|
||||
if (viewState.type !== "options") return null;
|
||||
const { provider } = viewState;
|
||||
const options = ["Update API key", "Disconnect", "Back"];
|
||||
const options = ["Disconnect", "Back"];
|
||||
|
||||
return (
|
||||
<>
|
||||
@@ -677,6 +962,8 @@ export function ProviderSelector({
|
||||
|
||||
{viewState.type === "list" && renderListView()}
|
||||
{viewState.type === "input" && renderInputView()}
|
||||
{viewState.type === "methodSelect" && renderMethodSelectView()}
|
||||
{viewState.type === "profileSelect" && renderProfileSelectView()}
|
||||
{viewState.type === "multiInput" && renderMultiInputView()}
|
||||
{viewState.type === "options" && renderOptionsView()}
|
||||
</Box>
|
||||
|
||||
@@ -15,6 +15,14 @@ export interface ProviderField {
|
||||
secret?: boolean; // If true, mask input like a password
|
||||
}
|
||||
|
||||
// Auth method definition for providers with multiple auth options
|
||||
export interface AuthMethod {
|
||||
id: string;
|
||||
label: string;
|
||||
description: string;
|
||||
fields: ProviderField[];
|
||||
}
|
||||
|
||||
// Provider configuration for the /connect UI
|
||||
export const BYOK_PROVIDERS = [
|
||||
{
|
||||
@@ -66,11 +74,31 @@ export const BYOK_PROVIDERS = [
|
||||
description: "Connect to Claude on Amazon Bedrock",
|
||||
providerType: "bedrock",
|
||||
providerName: "lc-bedrock",
|
||||
fields: [
|
||||
{ key: "accessKey", label: "AWS Access Key ID", placeholder: "AKIA..." },
|
||||
{ key: "apiKey", label: "AWS Secret Access Key", secret: true },
|
||||
{ key: "region", label: "AWS Region", placeholder: "us-east-1" },
|
||||
] as ProviderField[],
|
||||
authMethods: [
|
||||
{
|
||||
id: "iam",
|
||||
label: "AWS Access Keys",
|
||||
description: "Enter access key and secret key manually",
|
||||
fields: [
|
||||
{
|
||||
key: "accessKey",
|
||||
label: "AWS Access Key ID",
|
||||
placeholder: "AKIA...",
|
||||
},
|
||||
{ key: "apiKey", label: "AWS Secret Access Key", secret: true },
|
||||
{ key: "region", label: "AWS Region", placeholder: "us-east-1" },
|
||||
],
|
||||
},
|
||||
{
|
||||
id: "profile",
|
||||
label: "AWS Profile",
|
||||
description: "Load credentials from ~/.aws/credentials",
|
||||
fields: [
|
||||
{ key: "profile", label: "Profile Name", placeholder: "default" },
|
||||
{ key: "region", label: "AWS Region", placeholder: "us-east-1" },
|
||||
],
|
||||
},
|
||||
] as AuthMethod[],
|
||||
},
|
||||
] as const;
|
||||
|
||||
@@ -189,12 +217,14 @@ export async function checkProviderApiKey(
|
||||
apiKey: string,
|
||||
accessKey?: string,
|
||||
region?: string,
|
||||
profile?: string,
|
||||
): Promise<void> {
|
||||
await providersRequest<{ message: string }>("POST", "/v1/providers/check", {
|
||||
provider_type: providerType,
|
||||
api_key: apiKey,
|
||||
...(accessKey && { access_key: accessKey }),
|
||||
...(region && { region }),
|
||||
...(profile && { profile }),
|
||||
});
|
||||
}
|
||||
|
||||
@@ -207,6 +237,7 @@ export async function createProvider(
|
||||
apiKey: string,
|
||||
accessKey?: string,
|
||||
region?: string,
|
||||
profile?: string,
|
||||
): Promise<ProviderResponse> {
|
||||
return providersRequest<ProviderResponse>("POST", "/v1/providers", {
|
||||
name: providerName,
|
||||
@@ -214,6 +245,7 @@ export async function createProvider(
|
||||
api_key: apiKey,
|
||||
...(accessKey && { access_key: accessKey }),
|
||||
...(region && { region }),
|
||||
...(profile && { profile }),
|
||||
});
|
||||
}
|
||||
|
||||
@@ -225,6 +257,7 @@ export async function updateProvider(
|
||||
apiKey: string,
|
||||
accessKey?: string,
|
||||
region?: string,
|
||||
profile?: string,
|
||||
): Promise<ProviderResponse> {
|
||||
return providersRequest<ProviderResponse>(
|
||||
"PATCH",
|
||||
@@ -233,6 +266,7 @@ export async function updateProvider(
|
||||
api_key: apiKey,
|
||||
...(accessKey && { access_key: accessKey }),
|
||||
...(region && { region }),
|
||||
...(profile && { profile }),
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -254,14 +288,22 @@ export async function createOrUpdateProvider(
|
||||
apiKey: string,
|
||||
accessKey?: string,
|
||||
region?: string,
|
||||
profile?: string,
|
||||
): Promise<ProviderResponse> {
|
||||
const existing = await getProviderByName(providerName);
|
||||
|
||||
if (existing) {
|
||||
return updateProvider(existing.id, apiKey, accessKey, region);
|
||||
return updateProvider(existing.id, apiKey, accessKey, region, profile);
|
||||
}
|
||||
|
||||
return createProvider(providerType, providerName, apiKey, accessKey, region);
|
||||
return createProvider(
|
||||
providerType,
|
||||
providerName,
|
||||
apiKey,
|
||||
accessKey,
|
||||
region,
|
||||
profile,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
120
src/utils/aws-credentials.ts
Normal file
120
src/utils/aws-credentials.ts
Normal file
@@ -0,0 +1,120 @@
|
||||
/**
|
||||
* Utility to parse AWS credentials from ~/.aws/credentials
|
||||
*/
|
||||
|
||||
import { readFile } from "node:fs/promises";
|
||||
import { homedir } from "node:os";
|
||||
import { join } from "node:path";
|
||||
|
||||
export interface AwsProfile {
|
||||
name: string;
|
||||
accessKeyId?: string;
|
||||
secretAccessKey?: string;
|
||||
region?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse AWS credentials file and return list of profiles
|
||||
*/
|
||||
export async function parseAwsCredentials(): Promise<AwsProfile[]> {
|
||||
const credentialsPath = join(homedir(), ".aws", "credentials");
|
||||
const configPath = join(homedir(), ".aws", "config");
|
||||
|
||||
const profiles: Map<string, AwsProfile> = new Map();
|
||||
|
||||
// Parse credentials file
|
||||
try {
|
||||
const content = await readFile(credentialsPath, "utf-8");
|
||||
parseIniFile(content, profiles, false);
|
||||
} catch {
|
||||
// Credentials file doesn't exist or can't be read
|
||||
}
|
||||
|
||||
// Parse config file for region info
|
||||
try {
|
||||
const content = await readFile(configPath, "utf-8");
|
||||
parseIniFile(content, profiles, true);
|
||||
} catch {
|
||||
// Config file doesn't exist or can't be read
|
||||
}
|
||||
|
||||
return Array.from(profiles.values());
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse INI-style AWS config/credentials file
|
||||
*/
|
||||
function parseIniFile(
|
||||
content: string,
|
||||
profiles: Map<string, AwsProfile>,
|
||||
isConfig: boolean,
|
||||
): void {
|
||||
const lines = content.split("\n");
|
||||
let currentProfile: string | null = null;
|
||||
|
||||
for (const line of lines) {
|
||||
const trimmed = line.trim();
|
||||
|
||||
// Skip empty lines and comments
|
||||
if (!trimmed || trimmed.startsWith("#") || trimmed.startsWith(";")) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check for profile header
|
||||
const headerMatch = trimmed.match(/^\[(.+)\]$/);
|
||||
if (headerMatch?.[1]) {
|
||||
let profileName: string = headerMatch[1];
|
||||
// In config file, profiles are prefixed with "profile " (except default)
|
||||
if (isConfig && profileName.startsWith("profile ")) {
|
||||
profileName = profileName.slice(8);
|
||||
}
|
||||
currentProfile = profileName;
|
||||
|
||||
if (!profiles.has(profileName)) {
|
||||
profiles.set(profileName, { name: profileName });
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Parse key=value pairs
|
||||
if (currentProfile) {
|
||||
const kvMatch = trimmed.match(/^([^=]+)=(.*)$/);
|
||||
if (kvMatch?.[1] && kvMatch[2] !== undefined) {
|
||||
const key = kvMatch[1].trim();
|
||||
const value = kvMatch[2].trim();
|
||||
const profile = profiles.get(currentProfile);
|
||||
if (!profile) continue;
|
||||
|
||||
switch (key) {
|
||||
case "aws_access_key_id":
|
||||
profile.accessKeyId = value;
|
||||
break;
|
||||
case "aws_secret_access_key":
|
||||
profile.secretAccessKey = value;
|
||||
break;
|
||||
case "region":
|
||||
profile.region = value;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a specific profile by name
|
||||
*/
|
||||
export async function getAwsProfile(
|
||||
profileName: string,
|
||||
): Promise<AwsProfile | null> {
|
||||
const profiles = await parseAwsCredentials();
|
||||
return profiles.find((p) => p.name === profileName) || null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get list of available profile names
|
||||
*/
|
||||
export async function getAwsProfileNames(): Promise<string[]> {
|
||||
const profiles = await parseAwsCredentials();
|
||||
return profiles.map((p) => p.name);
|
||||
}
|
||||
Reference in New Issue
Block a user