diff --git a/experimental/ssh/cmd/connect.go b/experimental/ssh/cmd/connect.go index 6c04db57e2..4eca1aee7b 100644 --- a/experimental/ssh/cmd/connect.go +++ b/experimental/ssh/cmd/connect.go @@ -19,12 +19,6 @@ func newConnectCommand() *cobra.Command { This command establishes an SSH connection to Databricks compute, setting up the SSH server and handling the connection proxy. -For dedicated clusters: - databricks ssh connect --cluster= - -For serverless compute: - databricks ssh connect --name= [--accelerator=] - ` + disclaimer, } @@ -32,6 +26,7 @@ For serverless compute: var connectionName string var accelerator string var proxyMode bool + var ide string var serverMetadata string var shutdownDelay time.Duration var maxClients int @@ -42,12 +37,17 @@ For serverless compute: var liteswap string cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)") - cmd.Flags().StringVar(&connectionName, "name", "", "Connection name (for serverless compute)") - cmd.Flags().StringVar(&accelerator, "accelerator", "", "GPU accelerator type for serverless compute (GPU_1xA10 or GPU_8xH100)") cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects") cmd.Flags().IntVar(&maxClients, "max-clients", defaultMaxClients, "Maximum number of SSH clients") cmd.Flags().BoolVar(&autoStartCluster, "auto-start-cluster", true, "Automatically start the cluster if it is not running") + cmd.Flags().StringVar(&connectionName, "name", "", "Connection name (for serverless compute)") + cmd.Flags().MarkHidden("name") + cmd.Flags().StringVar(&accelerator, "accelerator", "", "GPU accelerator type (GPU_1xA10 or GPU_8xH100)") + cmd.Flags().MarkHidden("accelerator") + cmd.Flags().StringVar(&ide, "ide", "", "Open remote IDE window (vscode or cursor)") + cmd.Flags().MarkHidden("ide") + cmd.Flags().BoolVar(&proxyMode, "proxy", false, "ProxyCommand mode") cmd.Flags().MarkHidden("proxy") cmd.Flags().StringVar(&serverMetadata, "metadata", "", "Metadata of the running SSH server (format: ,)") @@ -80,7 +80,7 @@ For serverless compute: wsClient := cmdctx.WorkspaceClient(ctx) if !proxyMode && clusterID == "" && connectionName == "" { - return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the serverless connection name") + return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)") } if accelerator != "" && connectionName == "" { @@ -89,7 +89,7 @@ For serverless compute: // Remove when we add support for serverless CPU if connectionName != "" && accelerator == "" { - return errors.New("--name flag requires --accelerator to be set (e.g. for now we only support serverless GPU compute)") + return errors.New("--name flag requires --accelerator to be set (for now we only support serverless GPU compute)") } // TODO: validate connectionName if provided @@ -100,6 +100,7 @@ For serverless compute: ConnectionName: connectionName, Accelerator: accelerator, ProxyMode: proxyMode, + IDE: ide, ServerMetadata: serverMetadata, ShutdownDelay: shutdownDelay, MaxClients: maxClients, diff --git a/experimental/ssh/cmd/setup.go b/experimental/ssh/cmd/setup.go index 3e4523904c..81b7863666 100644 --- a/experimental/ssh/cmd/setup.go +++ b/experimental/ssh/cmd/setup.go @@ -1,9 +1,11 @@ package ssh import ( + "fmt" "time" "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/experimental/ssh/internal/client" "github.com/databricks/cli/experimental/ssh/internal/setup" "github.com/databricks/cli/libs/cmdctx" "github.com/spf13/cobra" @@ -43,16 +45,27 @@ an SSH host configuration to your SSH config file. cmd.RunE = func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - client := cmdctx.WorkspaceClient(ctx) - opts := setup.SetupOptions{ + wsClient := cmdctx.WorkspaceClient(ctx) + setupOpts := setup.SetupOptions{ HostName: hostName, ClusterID: clusterID, AutoStartCluster: autoStartCluster, SSHConfigPath: sshConfigPath, ShutdownDelay: shutdownDelay, - Profile: client.Config.Profile, + Profile: wsClient.Config.Profile, } - return setup.Setup(ctx, client, opts) + clientOpts := client.ClientOptions{ + ClusterID: setupOpts.ClusterID, + AutoStartCluster: setupOpts.AutoStartCluster, + ShutdownDelay: setupOpts.ShutdownDelay, + Profile: setupOpts.Profile, + } + proxyCommand, err := clientOpts.ToProxyCommand() + if err != nil { + return fmt.Errorf("failed to generate ProxyCommand: %w", err) + } + setupOpts.ProxyCommand = proxyCommand + return setup.Setup(ctx, wsClient, setupOpts) } return cmd diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 839705c4ec..940f792f0e 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -1,11 +1,9 @@ package client import ( - "bytes" "context" _ "embed" "encoding/base64" - "encoding/json" "errors" "fmt" "io" @@ -21,6 +19,7 @@ import ( "github.com/databricks/cli/experimental/ssh/internal/keys" "github.com/databricks/cli/experimental/ssh/internal/proxy" + "github.com/databricks/cli/experimental/ssh/internal/sshconfig" sshWorkspace "github.com/databricks/cli/experimental/ssh/internal/workspace" "github.com/databricks/cli/internal/build" "github.com/databricks/cli/libs/cmdio" @@ -41,6 +40,11 @@ var errServerMetadata = errors.New("server metadata error") const ( sshServerTaskKey = "start_ssh_server" serverlessEnvironmentKey = "ssh_tunnel_serverless" + + VSCodeOption = "vscode" + VSCodeCommand = "code" + CursorOption = "cursor" + CursorCommand = "cursor" ) type ClientOptions struct { @@ -58,6 +62,8 @@ type ClientOptions struct { // to the cluster and proxy all traffic through stdin/stdout. // In the non proxy mode the CLI spawns an ssh client with the ProxyCommand config. ProxyMode bool + // Open remote IDE window with a specific ssh config (empty, 'vscode', or 'cursor') + IDE string // Expected format: ",,". // If present, the CLI won't attempt to start the server. ServerMetadata string @@ -171,8 +177,7 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt } // Only check cluster state for dedicated clusters - // TODO: we can remove liteswap check when we can start serverless GPU clusters via API. - if !opts.IsServerlessMode() && opts.Liteswap == "" { + if !opts.IsServerlessMode() { err := checkClusterState(ctx, client, opts.ClusterID, opts.AutoStartCluster) if err != nil { return err @@ -250,12 +255,88 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt if opts.ProxyMode { return runSSHProxy(ctx, client, serverPort, clusterID, opts) + } else if opts.IDE != "" { + return runIDE(ctx, client, userName, keyPath, serverPort, clusterID, opts) } else { cmdio.LogString(ctx, fmt.Sprintf("Additional SSH arguments: %v", opts.AdditionalArgs)) return spawnSSHClient(ctx, userName, keyPath, serverPort, clusterID, opts) } } +func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error { + if opts.IDE != VSCodeOption && opts.IDE != CursorOption { + return fmt.Errorf("invalid IDE value: %s, expected '%s' or '%s'", opts.IDE, VSCodeOption, CursorOption) + } + + connectionName := opts.SessionIdentifier() + if connectionName == "" { + return errors.New("connection name is required for IDE integration") + } + + // Get Databricks user name for the workspace path + currentUser, err := client.CurrentUser.Me(ctx) + if err != nil { + return fmt.Errorf("failed to get current user: %w", err) + } + databricksUserName := currentUser.UserName + + // Ensure SSH config entry exists + configPath, err := sshconfig.GetMainConfigPath() + if err != nil { + return fmt.Errorf("failed to get SSH config path: %w", err) + } + + err = ensureSSHConfigEntry(ctx, configPath, connectionName, userName, keyPath, serverPort, clusterID, opts) + if err != nil { + return fmt.Errorf("failed to ensure SSH config entry: %w", err) + } + + ideCommand := VSCodeCommand + if opts.IDE == CursorOption { + ideCommand = CursorCommand + } + + // Construct the remote SSH URI + // Format: ssh-remote+@ /Workspace/Users// + remoteURI := fmt.Sprintf("ssh-remote+%s@%s", userName, connectionName) + remotePath := fmt.Sprintf("/Workspace/Users/%s/", databricksUserName) + + cmdio.LogString(ctx, fmt.Sprintf("Launching %s with remote URI: %s and path: %s", opts.IDE, remoteURI, remotePath)) + + ideCmd := exec.CommandContext(ctx, ideCommand, "--remote", remoteURI, remotePath) + ideCmd.Stdout = os.Stdout + ideCmd.Stderr = os.Stderr + + return ideCmd.Run() +} + +func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error { + // Ensure the Include directive exists in the main SSH config + err := sshconfig.EnsureIncludeDirective(configPath) + if err != nil { + return err + } + + // Generate ProxyCommand with server metadata + optsWithMetadata := opts + optsWithMetadata.ServerMetadata = FormatMetadata(userName, serverPort, clusterID) + + proxyCommand, err := optsWithMetadata.ToProxyCommand() + if err != nil { + return fmt.Errorf("failed to generate ProxyCommand: %w", err) + } + + hostConfig := sshconfig.GenerateHostConfig(hostName, userName, keyPath, proxyCommand) + + _, err = sshconfig.CreateOrUpdateHostConfig(ctx, hostName, hostConfig, true) + if err != nil { + return err + } + + cmdio.LogString(ctx, fmt.Sprintf("Updated SSH config entry for '%s'", hostName)) + return nil +} + // getServerMetadata retrieves the server metadata from the workspace and validates it via Driver Proxy. // sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). // For dedicated clusters, clusterID should be the same as sessionID. @@ -265,7 +346,7 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, if err != nil { return 0, "", "", errors.Join(errServerMetadata, err) } - cmdio.LogString(ctx, "Workspace metadata: "+fmt.Sprintf("%+v", wsMetadata)) + log.Debugf(ctx, "Workspace metadata: %+v", wsMetadata) // For serverless mode, the cluster ID comes from the metadata effectiveClusterID := clusterID @@ -352,11 +433,6 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, cmdio.LogString(ctx, "Submitting a job to start the ssh server...") - // Use manual HTTP call when hardware_accelerator is needed (SDK doesn't support it yet) - if opts.Accelerator != "" { - return submitSSHTunnelJobManual(ctx, client, jobNotebookPath, baseParams, opts) - } - task := jobs.SubmitTask{ TaskKey: sshServerTaskKey, NotebookTask: &jobs.NotebookTask{ @@ -368,6 +444,12 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, if opts.IsServerlessMode() { task.EnvironmentKey = serverlessEnvironmentKey + if opts.Accelerator != "" { + cmdio.LogString(ctx, "Using accelerator: "+opts.Accelerator) + task.Compute = &jobs.Compute{ + HardwareAccelerator: compute.HardwareAcceleratorType(opts.Accelerator), + } + } } else { task.ExistingClusterId = opts.ClusterID } @@ -399,97 +481,6 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, return waitForJobToStart(ctx, client, waiter.RunId, opts.TaskStartupTimeout) } -// submitSSHTunnelJobManual submits a job using manual HTTP call for features not yet supported by the SDK. -// Currently used for hardware_accelerator field which is not yet in the SDK. -func submitSSHTunnelJobManual(ctx context.Context, client *databricks.WorkspaceClient, jobNotebookPath string, baseParams map[string]string, opts ClientOptions) error { - sessionID := opts.SessionIdentifier() - sshTunnelJobName := "ssh-server-bootstrap-" + sessionID - - // Construct the request payload manually to allow custom parameters - task := map[string]any{ - "task_key": sshServerTaskKey, - "notebook_task": map[string]any{ - "notebook_path": jobNotebookPath, - "base_parameters": baseParams, - }, - "timeout_seconds": int(opts.ServerTimeout.Seconds()), - } - - if opts.IsServerlessMode() { - task["environment_key"] = serverlessEnvironmentKey - if opts.Accelerator != "" { - cmdio.LogString(ctx, "Using accelerator: "+opts.Accelerator) - task["compute"] = map[string]any{ - "hardware_accelerator": opts.Accelerator, - } - } - } else { - task["existing_cluster_id"] = opts.ClusterID - } - - submitRequest := map[string]any{ - "run_name": sshTunnelJobName, - "timeout_seconds": int(opts.ServerTimeout.Seconds()), - "tasks": []map[string]any{task}, - } - - if opts.IsServerlessMode() { - submitRequest["environments"] = []map[string]any{ - { - "environment_key": serverlessEnvironmentKey, - "spec": map[string]any{ - "environment_version": "3", - }, - }, - } - } - - requestBody, err := json.Marshal(submitRequest) - if err != nil { - return fmt.Errorf("failed to marshal request body: %w", err) - } - - cmdio.LogString(ctx, "Request body: "+string(requestBody)) - - apiURL := client.Config.Host + "/api/2.1/jobs/runs/submit" - req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(requestBody)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - if err := client.Config.Authenticate(req); err != nil { - return fmt.Errorf("failed to authenticate request: %w", err) - } - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return fmt.Errorf("failed to submit job: %w", err) - } - defer resp.Body.Close() - - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read response body: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("failed to submit job, status code %d: %s", resp.StatusCode, string(responseBody)) - } - - var result struct { - RunID int64 `json:"run_id"` - } - if err := json.Unmarshal(responseBody, &result); err != nil { - return fmt.Errorf("failed to parse response: %w", err) - } - - cmdio.LogString(ctx, fmt.Sprintf("Job submitted successfully with run ID: %d", result.RunID)) - - // For manual submissions we still need to poll manually - return waitForJobToStart(ctx, client, result.RunID, opts.TaskStartupTimeout) -} - func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, serverPort int, clusterID string, opts ClientOptions) error { // Create a copy with metadata for the ProxyCommand optsWithMetadata := opts @@ -516,8 +507,7 @@ func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, server sshArgs = append(sshArgs, hostName) sshArgs = append(sshArgs, opts.AdditionalArgs...) - cmdio.LogString(ctx, "Launching SSH client: ssh "+strings.Join(sshArgs, " ")) - + log.Debugf(ctx, "Launching SSH client: ssh %s", strings.Join(sshArgs, " ")) sshCmd := exec.CommandContext(ctx, "ssh", sshArgs...) sshCmd.Stdin = os.Stdin diff --git a/experimental/ssh/internal/keys/secrets.go b/experimental/ssh/internal/keys/secrets.go index eac692f235..d4e00d10ba 100644 --- a/experimental/ssh/internal/keys/secrets.go +++ b/experimental/ssh/internal/keys/secrets.go @@ -18,10 +18,23 @@ func CreateKeysSecretScope(ctx context.Context, client *databricks.WorkspaceClie return "", fmt.Errorf("failed to get current user: %w", err) } secretScopeName := fmt.Sprintf("%s-%s-ssh-tunnel-keys", me.UserName, sessionID) + + // Do not create the scope if it already exists. + // We can instead filter out "resource already exists" errors from CreateScope, + // but that API can also lead to "limit exceeded" errors, even if the scope does actually exist. + scope, err := client.Secrets.ListSecretsByScope(ctx, secretScopeName) + if err != nil && !errors.Is(err, databricks.ErrResourceDoesNotExist) { + return "", fmt.Errorf("failed to check if secret scope %s exists: %w", secretScopeName, err) + } + + if scope != nil && err == nil { + return secretScopeName, nil + } + err = client.Secrets.CreateScope(ctx, workspace.CreateScope{ Scope: secretScopeName, }) - if err != nil && !errors.Is(err, databricks.ErrResourceAlreadyExists) { + if err != nil { return "", fmt.Errorf("failed to create secrets scope: %w", err) } return secretScopeName, nil diff --git a/experimental/ssh/internal/server/sshd.go b/experimental/ssh/internal/server/sshd.go index 7a038e73b5..c8f23d02a5 100644 --- a/experimental/ssh/internal/server/sshd.go +++ b/experimental/ssh/internal/server/sshd.go @@ -52,15 +52,13 @@ func prepareSSHDConfig(ctx context.Context, client *databricks.WorkspaceClient, return "", err } - // Set all available env vars, wrapping values in quotes and escaping quotes inside values + // Set all available env vars, wrapping values in quotes, escaping quotes, and stripping newlines setEnv := "SetEnv" for _, env := range os.Environ() { parts := strings.SplitN(env, "=", 2) - if len(parts) != 2 { - continue + if len(parts) == 2 { + setEnv += " " + parts[0] + "=\"" + escapeEnvValue(parts[1]) + "\"" } - valEscaped := strings.ReplaceAll(parts[1], "\"", "\\\"") - setEnv += " " + parts[0] + "=\"" + valEscaped + "\"" } setEnv += " DATABRICKS_CLI_UPSTREAM=databricks_ssh_tunnel" setEnv += " DATABRICKS_CLI_UPSTREAM_VERSION=" + opts.Version @@ -94,3 +92,13 @@ func prepareSSHDConfig(ctx context.Context, client *databricks.WorkspaceClient, func createSSHDProcess(ctx context.Context, configPath string) *exec.Cmd { return exec.CommandContext(ctx, "/usr/sbin/sshd", "-f", configPath, "-i") } + +// escapeEnvValue escapes a value for use in sshd SetEnv directive. +// It strips newlines and escapes backslashes and quotes. +func escapeEnvValue(val string) string { + val = strings.ReplaceAll(val, "\r", "") + val = strings.ReplaceAll(val, "\n", "") + val = strings.ReplaceAll(val, "\\", "\\\\") + val = strings.ReplaceAll(val, "\"", "\\\"") + return val +} diff --git a/experimental/ssh/internal/server/sshd_test.go b/experimental/ssh/internal/server/sshd_test.go new file mode 100644 index 0000000000..a453d987a0 --- /dev/null +++ b/experimental/ssh/internal/server/sshd_test.go @@ -0,0 +1,73 @@ +package server + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEscapeEnvValue(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "simple value", + input: "hello", + expected: "hello", + }, + { + name: "value with quotes", + input: `say "hello"`, + expected: `say \"hello\"`, + }, + { + name: "value with newline", + input: "line1\nline2", + expected: "line1line2", + }, + { + name: "value with carriage return", + input: "line1\rline2", + expected: "line1line2", + }, + { + name: "value with CRLF", + input: "line1\r\nline2", + expected: "line1line2", + }, + { + name: "value with quotes and newlines", + input: "say \"hello\"\nworld", + expected: `say \"hello\"world`, + }, + { + name: "empty value", + input: "", + expected: "", + }, + { + name: "only newlines", + input: "\n\r\n", + expected: "", + }, + { + name: "backslashes", + input: `foo\bar\`, + expected: `foo\\bar\\`, + }, + { + name: "backslash before quote", + input: `foo\"bar`, + expected: `foo\\\"bar`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := escapeEnvValue(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/experimental/ssh/internal/setup/setup.go b/experimental/ssh/internal/setup/setup.go index adfe204427..99b5a68902 100644 --- a/experimental/ssh/internal/setup/setup.go +++ b/experimental/ssh/internal/setup/setup.go @@ -4,14 +4,10 @@ import ( "context" "errors" "fmt" - "os" - "path/filepath" - "regexp" - "strings" "time" - "github.com/databricks/cli/experimental/ssh/internal/client" "github.com/databricks/cli/experimental/ssh/internal/keys" + "github.com/databricks/cli/experimental/ssh/internal/sshconfig" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/service/compute" @@ -32,6 +28,8 @@ type SetupOptions struct { SSHKeysDir string // Optional auth profile name. If present, will be added as --profile flag to the ProxyCommand Profile string + // Proxy command to use for the SSH connection + ProxyCommand string } func validateClusterAccess(ctx context.Context, client *databricks.WorkspaceClient, clusterID string) error { @@ -45,108 +43,16 @@ func validateClusterAccess(ctx context.Context, client *databricks.WorkspaceClie return nil } -func resolveConfigPath(configPath string) (string, error) { - if configPath != "" { - return configPath, nil - } - homeDir, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("failed to get home directory: %w", err) - } - return filepath.Join(homeDir, ".ssh", "config"), nil -} - func generateHostConfig(opts SetupOptions) (string, error) { identityFilePath, err := keys.GetLocalSSHKeyPath(opts.ClusterID, opts.SSHKeysDir) if err != nil { return "", fmt.Errorf("failed to get local keys folder: %w", err) } - clientOpts := client.ClientOptions{ - ClusterID: opts.ClusterID, - AutoStartCluster: opts.AutoStartCluster, - ShutdownDelay: opts.ShutdownDelay, - Profile: opts.Profile, - } - proxyCommand, err := clientOpts.ToProxyCommand() - if err != nil { - return "", fmt.Errorf("failed to generate ProxyCommand: %w", err) - } - - hostConfig := fmt.Sprintf(` -Host %s - User root - ConnectTimeout 360 - StrictHostKeyChecking accept-new - IdentitiesOnly yes - IdentityFile %q - ProxyCommand %s -`, opts.HostName, identityFilePath, proxyCommand) - + hostConfig := sshconfig.GenerateHostConfig(opts.HostName, "root", identityFilePath, opts.ProxyCommand) return hostConfig, nil } -func ensureSSHConfigExists(configPath string) error { - _, err := os.Stat(configPath) - if os.IsNotExist(err) { - sshDir := filepath.Dir(configPath) - err = os.MkdirAll(sshDir, 0o700) - if err != nil { - return fmt.Errorf("failed to create SSH directory: %w", err) - } - err = os.WriteFile(configPath, []byte(""), 0o600) - if err != nil { - return fmt.Errorf("failed to create SSH config file: %w", err) - } - return nil - } else if err != nil { - return fmt.Errorf("failed to check SSH config file: %w", err) - } - return nil -} - -func checkExistingHosts(content []byte, hostName string) (bool, error) { - existingContent := string(content) - pattern := fmt.Sprintf(`(?m)^\s*Host\s+%s\s*$`, regexp.QuoteMeta(hostName)) - matched, err := regexp.MatchString(pattern, existingContent) - if err != nil { - return false, fmt.Errorf("failed to check for existing host: %w", err) - } - if matched { - return true, nil - } - return false, nil -} - -func createBackup(content []byte, configPath string) (string, error) { - backupPath := configPath + ".bak" - err := os.WriteFile(backupPath, content, 0o600) - if err != nil { - return backupPath, fmt.Errorf("failed to create backup of SSH config file: %w", err) - } - return backupPath, nil -} - -func updateSSHConfigFile(configPath, hostConfig, hostName string) error { - content, err := os.ReadFile(configPath) - if err != nil { - return fmt.Errorf("failed to read SSH config file: %w", err) - } - - existingContent := string(content) - if !strings.HasSuffix(existingContent, "\n") && existingContent != "" { - existingContent += "\n" - } - newContent := existingContent + hostConfig - - err = os.WriteFile(configPath, []byte(newContent), 0o600) - if err != nil { - return fmt.Errorf("failed to update SSH config file: %w", err) - } - - return nil -} - func clusterSelectionPrompt(ctx context.Context, client *databricks.WorkspaceClient) (string, error) { sp := cmdio.NewSpinner(ctx) sp.Update("Loading clusters.") @@ -184,50 +90,51 @@ func Setup(ctx context.Context, client *databricks.WorkspaceClient, opts SetupOp return err } - configPath, err := resolveConfigPath(opts.SSHConfigPath) + configPath, err := sshconfig.GetMainConfigPathOrDefault(opts.SSHConfigPath) if err != nil { return err } - hostConfig, err := generateHostConfig(opts) + err = sshconfig.EnsureIncludeDirective(configPath) if err != nil { return err } - err = ensureSSHConfigExists(configPath) + hostConfig, err := generateHostConfig(opts) if err != nil { return err } - existingContent, err := os.ReadFile(configPath) + exists, err := sshconfig.HostConfigExists(opts.HostName) if err != nil { - return fmt.Errorf("failed to read SSH config file: %w", err) + return err } - if len(existingContent) > 0 { - exists, err := checkExistingHosts(existingContent, opts.HostName) + recreate := false + if exists { + recreate, err = sshconfig.PromptRecreateConfig(ctx, opts.HostName) if err != nil { return err } - if exists { - cmdio.LogString(ctx, fmt.Sprintf("Host '%s' already exists in the SSH config, skipping setup", opts.HostName)) + if !recreate { + cmdio.LogString(ctx, fmt.Sprintf("Skipping setup for host '%s'", opts.HostName)) return nil } - backupPath, err := createBackup(existingContent, configPath) - if err != nil { - return err - } - cmdio.LogString(ctx, "Created backup of existing SSH config at "+backupPath) } cmdio.LogString(ctx, "Adding new entry to the SSH config:\n"+hostConfig) - err = updateSSHConfigFile(configPath, hostConfig, opts.HostName) + _, err = sshconfig.CreateOrUpdateHostConfig(ctx, opts.HostName, hostConfig, recreate) + if err != nil { + return err + } + + hostConfigPath, err := sshconfig.GetHostConfigPath(opts.HostName) if err != nil { return err } - cmdio.LogString(ctx, fmt.Sprintf("Updated SSH config file at %s with '%s' host", configPath, opts.HostName)) + cmdio.LogString(ctx, fmt.Sprintf("Created SSH config file at %s for '%s' host", hostConfigPath, opts.HostName)) cmdio.LogString(ctx, fmt.Sprintf("You can now connect to the cluster using 'ssh %s' terminal command, or use remote capabilities of your IDE", opts.HostName)) return nil } diff --git a/experimental/ssh/internal/setup/setup_test.go b/experimental/ssh/internal/setup/setup_test.go index aa803dfe1c..975828a3c8 100644 --- a/experimental/ssh/internal/setup/setup_test.go +++ b/experimental/ssh/internal/setup/setup_test.go @@ -118,15 +118,24 @@ func TestGenerateProxyCommand_ServerlessModeWithAccelerator(t *testing.T) { } func TestGenerateHostConfig_Valid(t *testing.T) { - // Create a temporary directory for testing tmpDir := t.TempDir() + clientOpts := client.ClientOptions{ + ClusterID: "cluster-123", + AutoStartCluster: true, + ShutdownDelay: 30 * time.Second, + Profile: "test-profile", + } + proxyCommand, err := clientOpts.ToProxyCommand() + require.NoError(t, err) + opts := SetupOptions{ HostName: "test-host", ClusterID: "cluster-123", SSHKeysDir: tmpDir, ShutdownDelay: 30 * time.Second, Profile: "test-profile", + ProxyCommand: proxyCommand, } result, err := generateHostConfig(opts) @@ -139,29 +148,35 @@ func TestGenerateHostConfig_Valid(t *testing.T) { assert.Contains(t, result, "--shutdown-delay=30s") assert.Contains(t, result, "--profile=test-profile") - // Check that identity file path is included expectedKeyPath := filepath.Join(tmpDir, "cluster-123") assert.Contains(t, result, fmt.Sprintf(`IdentityFile %q`, expectedKeyPath)) } func TestGenerateHostConfig_WithoutProfile(t *testing.T) { - // Create a temporary directory for testing tmpDir := t.TempDir() + clientOpts := client.ClientOptions{ + ClusterID: "cluster-123", + AutoStartCluster: true, + ShutdownDelay: 30 * time.Second, + Profile: "", + } + proxyCommand, err := clientOpts.ToProxyCommand() + require.NoError(t, err) + opts := SetupOptions{ HostName: "test-host", ClusterID: "cluster-123", SSHKeysDir: tmpDir, ShutdownDelay: 30 * time.Second, - Profile: "", // No profile + Profile: "", + ProxyCommand: proxyCommand, } result, err := generateHostConfig(opts) assert.NoError(t, err) - // Should not contain profile option assert.NotContains(t, result, "--profile=") - // But should contain other elements assert.Contains(t, result, "Host test-host") assert.Contains(t, result, "--cluster=cluster-123") } @@ -187,181 +202,12 @@ func TestGenerateHostConfig_PathEscaping(t *testing.T) { assert.Contains(t, result, fmt.Sprintf(`IdentityFile %q`, expectedPath)) } -func TestEnsureSSHConfigExists(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, ".ssh", "config") - - err := ensureSSHConfigExists(configPath) - assert.NoError(t, err) - - // Check that directory was created - _, err = os.Stat(filepath.Dir(configPath)) - assert.NoError(t, err) - - // Check that file was created - _, err = os.Stat(configPath) - assert.NoError(t, err) - - // Check that file is empty - content, err := os.ReadFile(configPath) - assert.NoError(t, err) - assert.Empty(t, content) -} - -func TestCheckExistingHosts_NoExistingHost(t *testing.T) { - content := []byte(`Host other-host - User root - HostName example.com - -Host another-host - User admin -`) - exists, err := checkExistingHosts(content, "test-host") - assert.NoError(t, err) - assert.False(t, exists) -} - -func TestCheckExistingHosts_HostAlreadyExists(t *testing.T) { - content := []byte(`Host test-host - User root - HostName example.com - -Host another-host - User admin -`) - exists, err := checkExistingHosts(content, "another-host") - assert.NoError(t, err) - assert.True(t, exists) -} - -func TestCheckExistingHosts_EmptyContent(t *testing.T) { - content := []byte("") - exists, err := checkExistingHosts(content, "test-host") - assert.NoError(t, err) - assert.False(t, exists) -} - -func TestCheckExistingHosts_HostNameWithWhitespaces(t *testing.T) { - content := []byte(` Host test-host `) - exists, err := checkExistingHosts(content, "test-host") - assert.NoError(t, err) - assert.True(t, exists) -} - -func TestCheckExistingHosts_PartialNameMatch(t *testing.T) { - content := []byte(`Host test-host-long`) - exists, err := checkExistingHosts(content, "test-host") - assert.NoError(t, err) - assert.False(t, exists) -} - -func TestCreateBackup_CreatesBackupSuccessfully(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - content := []byte("original content") - - backupPath, err := createBackup(content, configPath) - assert.NoError(t, err) - assert.Equal(t, configPath+".bak", backupPath) - - // Check that backup file was created with correct content - backupContent, err := os.ReadFile(backupPath) - assert.NoError(t, err) - assert.Equal(t, content, backupContent) -} - -func TestCreateBackup_OverwritesExistingBackup(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - backupPath := configPath + ".bak" - - // Create existing backup - oldContent := []byte("old backup") - err := os.WriteFile(backupPath, oldContent, 0o644) - require.NoError(t, err) - - // Create new backup - newContent := []byte("new content") - resultPath, err := createBackup(newContent, configPath) - assert.NoError(t, err) - assert.Equal(t, backupPath, resultPath) - - // Check that backup was overwritten - backupContent, err := os.ReadFile(backupPath) - assert.NoError(t, err) - assert.Equal(t, newContent, backupContent) -} - -func TestUpdateSSHConfigFile_UpdatesSuccessfully(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - - // Create initial config file - initialContent := "# SSH Config\nHost existing\n User root\n" - err := os.WriteFile(configPath, []byte(initialContent), 0o600) - require.NoError(t, err) - - hostConfig := "\nHost new-host\n User root\n HostName example.com\n" - err = updateSSHConfigFile(configPath, hostConfig, "new-host") - assert.NoError(t, err) - - // Check that content was appended - finalContent, err := os.ReadFile(configPath) - assert.NoError(t, err) - expected := initialContent + hostConfig - assert.Equal(t, expected, string(finalContent)) -} - -func TestUpdateSSHConfigFile_AddsNewlineIfMissing(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - - // Create config file without trailing newline - initialContent := "Host existing\n User root" - err := os.WriteFile(configPath, []byte(initialContent), 0o600) - require.NoError(t, err) - - hostConfig := "\nHost new-host\n User root\n" - err = updateSSHConfigFile(configPath, hostConfig, "new-host") - assert.NoError(t, err) - - // Check that newline was added before the new content - finalContent, err := os.ReadFile(configPath) - assert.NoError(t, err) - expected := initialContent + "\n" + hostConfig - assert.Equal(t, expected, string(finalContent)) -} - -func TestUpdateSSHConfigFile_HandlesEmptyFile(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - - // Create empty config file - err := os.WriteFile(configPath, []byte(""), 0o600) - require.NoError(t, err) - - hostConfig := "Host new-host\n User root\n" - err = updateSSHConfigFile(configPath, hostConfig, "new-host") - assert.NoError(t, err) - - // Check that content was added without extra newlines - finalContent, err := os.ReadFile(configPath) - assert.NoError(t, err) - assert.Equal(t, hostConfig, string(finalContent)) -} - -func TestUpdateSSHConfigFile_HandlesReadError(t *testing.T) { - configPath := "/nonexistent/file" - hostConfig := "Host new-host\n" - - err := updateSSHConfigFile(configPath, hostConfig, "new-host") - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to read SSH config file") -} - func TestSetup_SuccessfulWithNewConfigFile(t *testing.T) { ctx := cmdio.MockDiscard(context.Background()) tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + configPath := filepath.Join(tmpDir, "ssh_config") m := mocks.NewMockWorkspaceClient(t) @@ -380,22 +226,43 @@ func TestSetup_SuccessfulWithNewConfigFile(t *testing.T) { Profile: "test-profile", } - err := Setup(ctx, m.WorkspaceClient, opts) + clientOpts := client.ClientOptions{ + ClusterID: opts.ClusterID, + AutoStartCluster: opts.AutoStartCluster, + ShutdownDelay: opts.ShutdownDelay, + Profile: opts.Profile, + } + proxyCommand, err := clientOpts.ToProxyCommand() + require.NoError(t, err) + opts.ProxyCommand = proxyCommand + + err = Setup(ctx, m.WorkspaceClient, opts) assert.NoError(t, err) - // Check that config file was created + // Check that main config has Include directive content, err := os.ReadFile(configPath) assert.NoError(t, err) - configStr := string(content) - assert.Contains(t, configStr, "Host test-host") - assert.Contains(t, configStr, "--cluster=cluster-123") - assert.Contains(t, configStr, "--profile=test-profile") + assert.Contains(t, configStr, "Include") + // SSH config uses forward slashes on all platforms + assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") + + // Check that host config file was created + hostConfigPath := filepath.Join(tmpDir, ".databricks", "ssh-tunnel-configs", "test-host") + hostContent, err := os.ReadFile(hostConfigPath) + assert.NoError(t, err) + hostConfigStr := string(hostContent) + assert.Contains(t, hostConfigStr, "Host test-host") + assert.Contains(t, hostConfigStr, "--cluster=cluster-123") + assert.Contains(t, hostConfigStr, "--profile=test-profile") } func TestSetup_SuccessfulWithExistingConfigFile(t *testing.T) { ctx := cmdio.MockDiscard(context.Background()) tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + configPath := filepath.Join(tmpDir, "ssh_config") // Create existing config file @@ -418,54 +285,34 @@ func TestSetup_SuccessfulWithExistingConfigFile(t *testing.T) { ShutdownDelay: 60 * time.Second, } - err = Setup(ctx, m.WorkspaceClient, opts) - assert.NoError(t, err) - - // Check that config file was updated and backup was created - content, err := os.ReadFile(configPath) - assert.NoError(t, err) - - configStr := string(content) - assert.Contains(t, configStr, "# Existing SSH Config") // Original content preserved - assert.Contains(t, configStr, "Host new-host") // New content added - assert.Contains(t, configStr, "--cluster=cluster-456") - - // Check backup was created - backupPath := configPath + ".bak" - backupContent, err := os.ReadFile(backupPath) - assert.NoError(t, err) - assert.Equal(t, existingContent, string(backupContent)) -} - -func TestSetup_DoesNotOverrideExistingHost(t *testing.T) { - ctx := cmdio.MockDiscard(context.Background()) - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "ssh_config") - - // Create config file with existing host - existingContent := "Host duplicate-host\n User root\n" - err := os.WriteFile(configPath, []byte(existingContent), 0o600) - require.NoError(t, err) - - m := mocks.NewMockWorkspaceClient(t) - clustersAPI := m.GetMockClustersAPI() - - clustersAPI.EXPECT().Get(ctx, compute.GetClusterRequest{ClusterId: "cluster-123"}).Return(&compute.ClusterDetails{ - DataSecurityMode: compute.DataSecurityModeSingleUser, - }, nil) - - opts := SetupOptions{ - HostName: "duplicate-host", // Same as existing - ClusterID: "cluster-123", - SSHConfigPath: configPath, - SSHKeysDir: tmpDir, - ShutdownDelay: 30 * time.Second, + clientOpts := client.ClientOptions{ + ClusterID: opts.ClusterID, + AutoStartCluster: opts.AutoStartCluster, + ShutdownDelay: opts.ShutdownDelay, + Profile: opts.Profile, } + proxyCommand, err := clientOpts.ToProxyCommand() + require.NoError(t, err) + opts.ProxyCommand = proxyCommand err = Setup(ctx, m.WorkspaceClient, opts) assert.NoError(t, err) + // Check that main config has Include directive and preserves existing content content, err := os.ReadFile(configPath) assert.NoError(t, err) - assert.Equal(t, "Host duplicate-host\n User root\n", string(content)) + configStr := string(content) + assert.Contains(t, configStr, "Include") + // SSH config uses forward slashes on all platforms + assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") + assert.Contains(t, configStr, "# Existing SSH Config") + assert.Contains(t, configStr, "Host existing-host") + + // Check that host config file was created + hostConfigPath := filepath.Join(tmpDir, ".databricks", "ssh-tunnel-configs", "new-host") + hostContent, err := os.ReadFile(hostConfigPath) + assert.NoError(t, err) + hostConfigStr := string(hostContent) + assert.Contains(t, hostConfigStr, "Host new-host") + assert.Contains(t, hostConfigStr, "--cluster=cluster-456") } diff --git a/experimental/ssh/internal/sshconfig/sshconfig.go b/experimental/ssh/internal/sshconfig/sshconfig.go new file mode 100644 index 0000000000..3a6713acbf --- /dev/null +++ b/experimental/ssh/internal/sshconfig/sshconfig.go @@ -0,0 +1,172 @@ +package sshconfig + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/databricks/cli/libs/cmdio" +) + +const ( + // configDirName is the directory name for Databricks SSH tunnel configs, relative to the user's home directory. + configDirName = ".databricks/ssh-tunnel-configs" +) + +func GetConfigDir() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get home directory: %w", err) + } + return filepath.Join(homeDir, configDirName), nil +} + +func GetMainConfigPath() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get home directory: %w", err) + } + return filepath.Join(homeDir, ".ssh", "config"), nil +} + +func GetMainConfigPathOrDefault(configPath string) (string, error) { + if configPath != "" { + return configPath, nil + } + return GetMainConfigPath() +} + +func EnsureMainConfigExists(configPath string) error { + _, err := os.Stat(configPath) + if os.IsNotExist(err) { + sshDir := filepath.Dir(configPath) + err = os.MkdirAll(sshDir, 0o700) + if err != nil { + return fmt.Errorf("failed to create SSH directory: %w", err) + } + err = os.WriteFile(configPath, []byte(""), 0o600) + if err != nil { + return fmt.Errorf("failed to create SSH config file: %w", err) + } + return nil + } + return err +} + +func EnsureIncludeDirective(configPath string) error { + configDir, err := GetConfigDir() + if err != nil { + return err + } + + err = os.MkdirAll(configDir, 0o700) + if err != nil { + return fmt.Errorf("failed to create Databricks SSH config directory: %w", err) + } + + err = EnsureMainConfigExists(configPath) + if err != nil { + return err + } + + content, err := os.ReadFile(configPath) + if err != nil { + return fmt.Errorf("failed to read SSH config file: %w", err) + } + + // Convert path to forward slashes for SSH config compatibility across platforms + configDirUnix := filepath.ToSlash(configDir) + + includeLine := fmt.Sprintf("Include %s/*", configDirUnix) + if strings.Contains(string(content), includeLine) { + return nil + } + + newContent := includeLine + "\n" + if len(content) > 0 && !strings.HasPrefix(string(content), "\n") { + newContent += "\n" + } + newContent += string(content) + + err = os.WriteFile(configPath, []byte(newContent), 0o600) + if err != nil { + return fmt.Errorf("failed to update SSH config file with Include directive: %w", err) + } + + return nil +} + +func GetHostConfigPath(hostName string) (string, error) { + configDir, err := GetConfigDir() + if err != nil { + return "", err + } + return filepath.Join(configDir, hostName), nil +} + +func HostConfigExists(hostName string) (bool, error) { + configPath, err := GetHostConfigPath(hostName) + if err != nil { + return false, err + } + _, err = os.Stat(configPath) + if os.IsNotExist(err) { + return false, nil + } + if err != nil { + return false, fmt.Errorf("failed to check host config file: %w", err) + } + return true, nil +} + +// Returns true if the config was created/updated, false if it was skipped. +func CreateOrUpdateHostConfig(ctx context.Context, hostName, hostConfig string, recreate bool) (bool, error) { + configPath, err := GetHostConfigPath(hostName) + if err != nil { + return false, err + } + + exists, err := HostConfigExists(hostName) + if err != nil { + return false, err + } + + if exists && !recreate { + return false, nil + } + + configDir := filepath.Dir(configPath) + err = os.MkdirAll(configDir, 0o700) + if err != nil { + return false, fmt.Errorf("failed to create config directory: %w", err) + } + + err = os.WriteFile(configPath, []byte(hostConfig), 0o600) + if err != nil { + return false, fmt.Errorf("failed to write host config file: %w", err) + } + + return true, nil +} + +func PromptRecreateConfig(ctx context.Context, hostName string) (bool, error) { + response, err := cmdio.AskYesOrNo(ctx, fmt.Sprintf("Host '%s' already exists. Do you want to recreate the config?", hostName)) + if err != nil { + return false, err + } + return response, nil +} + +func GenerateHostConfig(hostName, userName, identityFile, proxyCommand string) string { + return fmt.Sprintf(` +Host %s + User %s + ConnectTimeout 360 + StrictHostKeyChecking accept-new + IdentitiesOnly yes + IdentityFile %q + ProxyCommand %s +`, hostName, userName, identityFile, proxyCommand) +} diff --git a/experimental/ssh/internal/sshconfig/sshconfig_test.go b/experimental/ssh/internal/sshconfig/sshconfig_test.go new file mode 100644 index 0000000000..5fa13923ee --- /dev/null +++ b/experimental/ssh/internal/sshconfig/sshconfig_test.go @@ -0,0 +1,223 @@ +package sshconfig + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/databricks/cli/libs/cmdio" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetConfigDir(t *testing.T) { + dir, err := GetConfigDir() + assert.NoError(t, err) + assert.Contains(t, dir, filepath.Join(".databricks", "ssh-tunnel-configs")) +} + +func TestGetMainConfigPath(t *testing.T) { + path, err := GetMainConfigPath() + assert.NoError(t, err) + assert.Contains(t, path, filepath.Join(".ssh", "config")) +} + +func TestGetMainConfigPathOrDefault(t *testing.T) { + path, err := GetMainConfigPathOrDefault("/custom/path") + assert.NoError(t, err) + assert.Equal(t, "/custom/path", path) + + path, err = GetMainConfigPathOrDefault("") + assert.NoError(t, err) + assert.Contains(t, path, filepath.Join(".ssh", "config")) +} + +func TestEnsureMainConfigExists(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".ssh", "config") + + err := EnsureMainConfigExists(configPath) + assert.NoError(t, err) + + _, err = os.Stat(filepath.Dir(configPath)) + assert.NoError(t, err) + + _, err = os.Stat(configPath) + assert.NoError(t, err) + + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Empty(t, content) +} + +func TestEnsureIncludeDirective_NewConfig(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".ssh", "config") + + // Set home directory for test + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + err := EnsureIncludeDirective(configPath) + assert.NoError(t, err) + + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + + configStr := string(content) + assert.Contains(t, configStr, "Include") + // SSH config uses forward slashes on all platforms + assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") +} + +func TestEnsureIncludeDirective_AlreadyExists(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + configPath := filepath.Join(tmpDir, ".ssh", "config") + + configDir, err := GetConfigDir() + require.NoError(t, err) + + // Use forward slashes as that's what SSH config uses + configDirUnix := filepath.ToSlash(configDir) + existingContent := "Include " + configDirUnix + "/*\n\nHost example\n User test\n" + err = os.MkdirAll(filepath.Dir(configPath), 0o700) + require.NoError(t, err) + err = os.WriteFile(configPath, []byte(existingContent), 0o600) + require.NoError(t, err) + + err = EnsureIncludeDirective(configPath) + assert.NoError(t, err) + + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Equal(t, existingContent, string(content)) +} + +func TestEnsureIncludeDirective_PrependsToExisting(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".ssh", "config") + + // Set home directory for test + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + existingContent := "Host example\n User test\n" + err := os.MkdirAll(filepath.Dir(configPath), 0o700) + require.NoError(t, err) + err = os.WriteFile(configPath, []byte(existingContent), 0o600) + require.NoError(t, err) + + err = EnsureIncludeDirective(configPath) + assert.NoError(t, err) + + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + + configStr := string(content) + assert.Contains(t, configStr, "Include") + // SSH config uses forward slashes on all platforms + assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") + assert.Contains(t, configStr, "Host example") + + includeIndex := len("Include") + hostIndex := len(configStr) - len(existingContent) + assert.Less(t, includeIndex, hostIndex, "Include directive should come before existing content") +} + +func TestGetHostConfigPath(t *testing.T) { + path, err := GetHostConfigPath("test-host") + assert.NoError(t, err) + assert.Contains(t, path, filepath.Join(".databricks", "ssh-tunnel-configs", "test-host")) +} + +func TestHostConfigExists(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + exists, err := HostConfigExists("nonexistent") + assert.NoError(t, err) + assert.False(t, exists) + + configDir := filepath.Join(tmpDir, configDirName) + err = os.MkdirAll(configDir, 0o700) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(configDir, "existing-host"), []byte("config"), 0o600) + require.NoError(t, err) + + exists, err = HostConfigExists("existing-host") + assert.NoError(t, err) + assert.True(t, exists) +} + +func TestCreateOrUpdateHostConfig_NewConfig(t *testing.T) { + ctx := cmdio.MockDiscard(context.Background()) + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + hostConfig := "Host test\n User root\n" + created, err := CreateOrUpdateHostConfig(ctx, "test-host", hostConfig, false) + assert.NoError(t, err) + assert.True(t, created) + + configPath, err := GetHostConfigPath("test-host") + require.NoError(t, err) + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Equal(t, hostConfig, string(content)) +} + +func TestCreateOrUpdateHostConfig_ExistingConfigNoRecreate(t *testing.T) { + ctx := cmdio.MockDiscard(context.Background()) + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + configDir := filepath.Join(tmpDir, configDirName) + err := os.MkdirAll(configDir, 0o700) + require.NoError(t, err) + existingConfig := "Host test\n User admin\n" + err = os.WriteFile(filepath.Join(configDir, "test-host"), []byte(existingConfig), 0o600) + require.NoError(t, err) + + newConfig := "Host test\n User root\n" + created, err := CreateOrUpdateHostConfig(ctx, "test-host", newConfig, false) + assert.NoError(t, err) + assert.False(t, created) + + configPath, err := GetHostConfigPath("test-host") + require.NoError(t, err) + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Equal(t, existingConfig, string(content)) +} + +func TestCreateOrUpdateHostConfig_ExistingConfigWithRecreate(t *testing.T) { + ctx := cmdio.MockDiscard(context.Background()) + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + configDir := filepath.Join(tmpDir, configDirName) + err := os.MkdirAll(configDir, 0o700) + require.NoError(t, err) + existingConfig := "Host test\n User admin\n" + err = os.WriteFile(filepath.Join(configDir, "test-host"), []byte(existingConfig), 0o600) + require.NoError(t, err) + + newConfig := "Host test\n User root\n" + created, err := CreateOrUpdateHostConfig(ctx, "test-host", newConfig, true) + assert.NoError(t, err) + assert.True(t, created) + + configPath, err := GetHostConfigPath("test-host") + require.NoError(t, err) + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Equal(t, newConfig, string(content)) +}