diff --git a/src/cli/components/ProviderSelector.tsx b/src/cli/components/ProviderSelector.tsx index fa5b60e..92a8586 100644 --- a/src/cli/components/ProviderSelector.tsx +++ b/src/cli/components/ProviderSelector.tsx @@ -17,6 +17,7 @@ const SOLID_LINE = "─"; type ViewState = | { type: "list" } | { type: "input"; provider: ByokProvider } + | { type: "multiInput"; provider: ByokProvider } | { type: "options"; provider: ByokProvider; providerId: string }; type ValidationState = "idle" | "validating" | "valid" | "invalid"; @@ -46,6 +47,9 @@ export function ProviderSelector({ useState("idle"); const [validationError, setValidationError] = useState(null); const [optionIndex, setOptionIndex] = useState(0); + // Multi-field input state (for providers like Bedrock) + const [fieldValues, setFieldValues] = useState>({}); + const [focusedFieldIndex, setFocusedFieldIndex] = useState(0); const mountedRef = useRef(true); useEffect(() => { @@ -107,8 +111,15 @@ export function ProviderSelector({ setViewState({ type: "options", provider, providerId }); setOptionIndex(0); } + } else if ("fields" in provider && provider.fields) { + // Multi-field provider (like Bedrock) - show multi-input view + setViewState({ type: "multiInput", provider }); + setFieldValues({}); + setFocusedFieldIndex(0); + setValidationState("idle"); + setValidationError(null); } else { - // Show API key input for new provider + // Single API key input for regular providers setViewState({ type: "input", provider }); setApiKeyInput(""); setValidationState("idle"); @@ -171,6 +182,75 @@ export function ProviderSelector({ } }, [viewState, apiKeyInput, validationState]); + // Handle multi-field validation and saving (for providers like Bedrock) + const handleMultiFieldValidateAndSave = useCallback(async () => { + if (viewState.type !== "multiInput") return; + if (!("fields" in viewState.provider) || !viewState.provider.fields) return; + + const { provider } = viewState; + const fields = provider.fields; + + // Check all required fields are filled + const allFilled = fields.every((field) => fieldValues[field.key]?.trim()); + if (!allFilled) return; + + const apiKey = fieldValues.apiKey?.trim() || ""; + const accessKey = fieldValues.accessKey?.trim(); + const region = fieldValues.region?.trim(); + + // If already validated, save + if (validationState === "valid") { + try { + await createOrUpdateProvider( + provider.providerType, + provider.providerName, + apiKey, + accessKey, + region, + ); + // Refresh connected providers + const providers = await getConnectedProviders(); + if (mountedRef.current) { + setConnectedProviders(providers); + setViewState({ type: "list" }); + setFieldValues({}); + setValidationState("idle"); + } + } catch (err) { + if (mountedRef.current) { + setValidationError( + err instanceof Error ? err.message : "Failed to save", + ); + setValidationState("invalid"); + } + } + return; + } + + // Validate the credentials + setValidationState("validating"); + setValidationError(null); + + try { + await checkProviderApiKey( + provider.providerType, + apiKey, + accessKey, + region, + ); + if (mountedRef.current) { + setValidationState("valid"); + } + } catch (err) { + if (mountedRef.current) { + setValidationState("invalid"); + setValidationError( + err instanceof Error ? err.message : "Invalid credentials", + ); + } + } + }, [viewState, fieldValues, validationState]); + // Handle disconnect const handleDisconnect = useCallback(async () => { if (viewState.type !== "options") return; @@ -248,6 +328,54 @@ export function ProviderSelector({ setValidationError(null); } } + } else if (viewState.type === "multiInput") { + if (!("fields" in viewState.provider) || !viewState.provider.fields) + return; + const fields = viewState.provider.fields; + const currentField = fields[focusedFieldIndex]; + if (!currentField) return; + + if (key.escape) { + // Back to list + setViewState({ type: "list" }); + setFieldValues({}); + setFocusedFieldIndex(0); + setValidationState("idle"); + setValidationError(null); + } else if (key.tab) { + // Move to next/prev field + if (key.shift) { + setFocusedFieldIndex((prev) => Math.max(0, prev - 1)); + } else { + setFocusedFieldIndex((prev) => Math.min(fields.length - 1, prev + 1)); + } + } else if (key.upArrow) { + setFocusedFieldIndex((prev) => Math.max(0, prev - 1)); + } else if (key.downArrow) { + setFocusedFieldIndex((prev) => Math.min(fields.length - 1, prev + 1)); + } else if (key.return) { + handleMultiFieldValidateAndSave(); + } else if (key.backspace || key.delete) { + setFieldValues((prev) => ({ + ...prev, + [currentField.key]: (prev[currentField.key] || "").slice(0, -1), + })); + // Reset validation if value changed + if (validationState !== "idle") { + setValidationState("idle"); + setValidationError(null); + } + } else if (input && !key.ctrl && !key.meta) { + setFieldValues((prev) => ({ + ...prev, + [currentField.key]: (prev[currentField.key] || "") + input, + })); + // Reset validation if value changed + if (validationState !== "idle") { + setValidationState("idle"); + setValidationError(null); + } + } } else if (viewState.type === "options") { const options = ["Update API key", "Disconnect", "Back"]; if (key.escape) { @@ -389,6 +517,105 @@ export function ProviderSelector({ ); }; + // Render multi-input view (for providers like Bedrock) + const renderMultiInputView = () => { + if (viewState.type !== "multiInput") return null; + if (!("fields" in viewState.provider) || !viewState.provider.fields) + return null; + + const { provider } = viewState; + const fields = provider.fields; + + // Check if all fields are filled + const allFilled = fields.every((field) => fieldValues[field.key]?.trim()); + + const statusText = + validationState === "validating" + ? " (validating...)" + : validationState === "valid" + ? " (credentials validated!)" + : validationState === "invalid" + ? ` (invalid${validationError ? `: ${validationError}` : ""})` + : ""; + + const statusColor = + validationState === "valid" + ? "green" + : validationState === "invalid" + ? "red" + : undefined; + + const footerText = + validationState === "valid" + ? "Enter to save · Esc cancel" + : allFilled + ? "Enter to validate · Tab/↑↓ navigate · Esc cancel" + : "Tab/↑↓ navigate · Esc cancel"; + + return ( + <> + + + Connect {provider.displayName} + + + + + {fields.map((field, index) => { + const isFocused = index === focusedFieldIndex; + const value = fieldValues[field.key] || ""; + const displayValue = field.secret ? maskApiKey(value) : value; + + return ( + + + {isFocused ? "> " : " "} + + + {field.label}: + + + + {displayValue || + (isFocused + ? `(${field.placeholder || "enter value"})` + : "")} + + + ); + })} + + + {(validationState !== "idle" || validationError) && ( + + + {" "} + {statusText} + + + )} + + + + {" "} + {footerText} + + + + ); + }; + // Render options view (for connected providers) const renderOptionsView = () => { if (viewState.type !== "options") return null; @@ -450,6 +677,7 @@ export function ProviderSelector({ {viewState.type === "list" && renderListView()} {viewState.type === "input" && renderInputView()} + {viewState.type === "multiInput" && renderMultiInputView()} {viewState.type === "options" && renderOptionsView()} ); diff --git a/src/providers/byok-providers.ts b/src/providers/byok-providers.ts index 5f2c898..1b04c1a 100644 --- a/src/providers/byok-providers.ts +++ b/src/providers/byok-providers.ts @@ -6,6 +6,14 @@ import { LETTA_CLOUD_API_URL } from "../auth/oauth"; import { settingsManager } from "../settings-manager"; +// Field definition for multi-field providers (like Bedrock) +export interface ProviderField { + key: string; + label: string; + placeholder?: string; + secret?: boolean; // If true, mask input like a password +} + // Provider configuration for the /connect UI export const BYOK_PROVIDERS = [ { @@ -44,6 +52,18 @@ export const BYOK_PROVIDERS = [ providerType: "google_ai", providerName: "lc-gemini", }, + { + id: "bedrock", + displayName: "AWS Bedrock", + description: "Connect to Claude on Amazon Bedrock", + providerType: "bedrock", + providerName: "lc-bedrock", + fields: [ + { key: "accessKey", label: "AWS Access Key ID", placeholder: "AKIA..." }, + { key: "apiKey", label: "AWS Secret Access Key", secret: true }, + { key: "region", label: "AWS Region", placeholder: "us-east-1" }, + ] as ProviderField[], + }, ] as const; export type ByokProviderId = (typeof BYOK_PROVIDERS)[number]["id"]; @@ -56,6 +76,8 @@ export interface ProviderResponse { provider_type: string; api_key?: string; base_url?: string; + access_key?: string; + region?: string; } /** @@ -161,10 +183,14 @@ export async function getProviderByName( export async function checkProviderApiKey( providerType: string, apiKey: string, + accessKey?: string, + region?: string, ): Promise { await providersRequest<{ message: string }>("POST", "/v1/providers/check", { provider_type: providerType, api_key: apiKey, + ...(accessKey && { access_key: accessKey }), + ...(region && { region }), }); } @@ -175,11 +201,15 @@ export async function createProvider( providerType: string, providerName: string, apiKey: string, + accessKey?: string, + region?: string, ): Promise { return providersRequest("POST", "/v1/providers", { name: providerName, provider_type: providerType, api_key: apiKey, + ...(accessKey && { access_key: accessKey }), + ...(region && { region }), }); } @@ -189,12 +219,16 @@ export async function createProvider( export async function updateProvider( providerId: string, apiKey: string, + accessKey?: string, + region?: string, ): Promise { return providersRequest( "PATCH", `/v1/providers/${providerId}`, { api_key: apiKey, + ...(accessKey && { access_key: accessKey }), + ...(region && { region }), }, ); } @@ -214,14 +248,16 @@ export async function createOrUpdateProvider( providerType: string, providerName: string, apiKey: string, + accessKey?: string, + region?: string, ): Promise { const existing = await getProviderByName(providerName); if (existing) { - return updateProvider(existing.id, apiKey); + return updateProvider(existing.id, apiKey, accessKey, region); } - return createProvider(providerType, providerName, apiKey); + return createProvider(providerType, providerName, apiKey, accessKey, region); } /**