Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions experimental/ssh/cmd/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,14 @@ func newConnectCommand() *cobra.Command {
This command establishes an SSH connection to Databricks compute, setting up
the SSH server and handling the connection proxy.

For dedicated clusters:
databricks ssh connect --cluster=<cluster-id>

For serverless compute:
databricks ssh connect --name=<connection-name> [--accelerator=<accelerator>]

` + disclaimer,
}

var clusterID string
var connectionName string
var accelerator string
var proxyMode bool
var ide string
var serverMetadata string
var shutdownDelay time.Duration
var maxClients int
Expand All @@ -42,12 +37,17 @@ For serverless compute:
var liteswap string

cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)")
cmd.Flags().StringVar(&connectionName, "name", "", "Connection name (for serverless compute)")
cmd.Flags().StringVar(&accelerator, "accelerator", "", "GPU accelerator type for serverless compute (GPU_1xA10 or GPU_8xH100)")
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects")
cmd.Flags().IntVar(&maxClients, "max-clients", defaultMaxClients, "Maximum number of SSH clients")
cmd.Flags().BoolVar(&autoStartCluster, "auto-start-cluster", true, "Automatically start the cluster if it is not running")

cmd.Flags().StringVar(&connectionName, "name", "", "Connection name (for serverless compute)")
cmd.Flags().MarkHidden("name")
cmd.Flags().StringVar(&accelerator, "accelerator", "", "GPU accelerator type (GPU_1xA10 or GPU_8xH100)")
cmd.Flags().MarkHidden("accelerator")
cmd.Flags().StringVar(&ide, "ide", "", "Open remote IDE window (vscode or cursor)")
cmd.Flags().MarkHidden("ide")

cmd.Flags().BoolVar(&proxyMode, "proxy", false, "ProxyCommand mode")
cmd.Flags().MarkHidden("proxy")
cmd.Flags().StringVar(&serverMetadata, "metadata", "", "Metadata of the running SSH server (format: <user_name>,<port>)")
Expand Down Expand Up @@ -80,7 +80,7 @@ For serverless compute:
wsClient := cmdctx.WorkspaceClient(ctx)

if !proxyMode && clusterID == "" && connectionName == "" {
return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the serverless connection name")
return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)")
}

if accelerator != "" && connectionName == "" {
Expand All @@ -89,7 +89,7 @@ For serverless compute:

// Remove when we add support for serverless CPU
if connectionName != "" && accelerator == "" {
return errors.New("--name flag requires --accelerator to be set (e.g. for now we only support serverless GPU compute)")
return errors.New("--name flag requires --accelerator to be set (for now we only support serverless GPU compute)")
}

// TODO: validate connectionName if provided
Expand All @@ -100,6 +100,7 @@ For serverless compute:
ConnectionName: connectionName,
Accelerator: accelerator,
ProxyMode: proxyMode,
IDE: ide,
ServerMetadata: serverMetadata,
ShutdownDelay: shutdownDelay,
MaxClients: maxClients,
Expand Down
21 changes: 17 additions & 4 deletions experimental/ssh/cmd/setup.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package ssh

import (
"fmt"
"time"

"github.com/databricks/cli/cmd/root"
"github.com/databricks/cli/experimental/ssh/internal/client"
"github.com/databricks/cli/experimental/ssh/internal/setup"
"github.com/databricks/cli/libs/cmdctx"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -43,16 +45,27 @@ an SSH host configuration to your SSH config file.

cmd.RunE = func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
client := cmdctx.WorkspaceClient(ctx)
opts := setup.SetupOptions{
wsClient := cmdctx.WorkspaceClient(ctx)
setupOpts := setup.SetupOptions{
HostName: hostName,
ClusterID: clusterID,
AutoStartCluster: autoStartCluster,
SSHConfigPath: sshConfigPath,
ShutdownDelay: shutdownDelay,
Profile: client.Config.Profile,
Profile: wsClient.Config.Profile,
}
return setup.Setup(ctx, client, opts)
clientOpts := client.ClientOptions{
ClusterID: setupOpts.ClusterID,
AutoStartCluster: setupOpts.AutoStartCluster,
ShutdownDelay: setupOpts.ShutdownDelay,
Profile: setupOpts.Profile,
}
proxyCommand, err := clientOpts.ToProxyCommand()
if err != nil {
return fmt.Errorf("failed to generate ProxyCommand: %w", err)
}
setupOpts.ProxyCommand = proxyCommand
return setup.Setup(ctx, wsClient, setupOpts)
}

return cmd
Expand Down
196 changes: 93 additions & 103 deletions experimental/ssh/internal/client/client.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
package client

import (
"bytes"
"context"
_ "embed"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
Expand All @@ -21,6 +19,7 @@ import (

"github.com/databricks/cli/experimental/ssh/internal/keys"
"github.com/databricks/cli/experimental/ssh/internal/proxy"
"github.com/databricks/cli/experimental/ssh/internal/sshconfig"
sshWorkspace "github.com/databricks/cli/experimental/ssh/internal/workspace"
"github.com/databricks/cli/internal/build"
"github.com/databricks/cli/libs/cmdio"
Expand All @@ -41,6 +40,11 @@ var errServerMetadata = errors.New("server metadata error")
const (
sshServerTaskKey = "start_ssh_server"
serverlessEnvironmentKey = "ssh_tunnel_serverless"

VSCodeOption = "vscode"
VSCodeCommand = "code"
CursorOption = "cursor"
CursorCommand = "cursor"
)

type ClientOptions struct {
Expand All @@ -58,6 +62,8 @@ type ClientOptions struct {
// to the cluster and proxy all traffic through stdin/stdout.
// In the non proxy mode the CLI spawns an ssh client with the ProxyCommand config.
ProxyMode bool
// Open remote IDE window with a specific ssh config (empty, 'vscode', or 'cursor')
IDE string
// Expected format: "<user_name>,<port>,<cluster_id>".
// If present, the CLI won't attempt to start the server.
ServerMetadata string
Expand Down Expand Up @@ -171,8 +177,7 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
}

// Only check cluster state for dedicated clusters
// TODO: we can remove liteswap check when we can start serverless GPU clusters via API.
if !opts.IsServerlessMode() && opts.Liteswap == "" {
if !opts.IsServerlessMode() {
err := checkClusterState(ctx, client, opts.ClusterID, opts.AutoStartCluster)
if err != nil {
return err
Expand Down Expand Up @@ -250,12 +255,88 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt

if opts.ProxyMode {
return runSSHProxy(ctx, client, serverPort, clusterID, opts)
} else if opts.IDE != "" {
return runIDE(ctx, client, userName, keyPath, serverPort, clusterID, opts)
} else {
cmdio.LogString(ctx, fmt.Sprintf("Additional SSH arguments: %v", opts.AdditionalArgs))
return spawnSSHClient(ctx, userName, keyPath, serverPort, clusterID, opts)
}
}

func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error {
if opts.IDE != VSCodeOption && opts.IDE != CursorOption {
return fmt.Errorf("invalid IDE value: %s, expected '%s' or '%s'", opts.IDE, VSCodeOption, CursorOption)
}

connectionName := opts.SessionIdentifier()
if connectionName == "" {
return errors.New("connection name is required for IDE integration")
}

// Get Databricks user name for the workspace path
currentUser, err := client.CurrentUser.Me(ctx)
if err != nil {
return fmt.Errorf("failed to get current user: %w", err)
}
databricksUserName := currentUser.UserName

// Ensure SSH config entry exists
configPath, err := sshconfig.GetMainConfigPath()
if err != nil {
return fmt.Errorf("failed to get SSH config path: %w", err)
}

err = ensureSSHConfigEntry(ctx, configPath, connectionName, userName, keyPath, serverPort, clusterID, opts)
if err != nil {
return fmt.Errorf("failed to ensure SSH config entry: %w", err)
}

ideCommand := VSCodeCommand
if opts.IDE == CursorOption {
ideCommand = CursorCommand
}

// Construct the remote SSH URI
// Format: ssh-remote+<server_user_name>@<connection_name> /Workspace/Users/<databricks_user_name>/
remoteURI := fmt.Sprintf("ssh-remote+%s@%s", userName, connectionName)
remotePath := fmt.Sprintf("/Workspace/Users/%s/", databricksUserName)

cmdio.LogString(ctx, fmt.Sprintf("Launching %s with remote URI: %s and path: %s", opts.IDE, remoteURI, remotePath))

ideCmd := exec.CommandContext(ctx, ideCommand, "--remote", remoteURI, remotePath)
ideCmd.Stdout = os.Stdout
ideCmd.Stderr = os.Stderr

return ideCmd.Run()
}

func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error {
// Ensure the Include directive exists in the main SSH config
err := sshconfig.EnsureIncludeDirective(configPath)
if err != nil {
return err
}

// Generate ProxyCommand with server metadata
optsWithMetadata := opts
optsWithMetadata.ServerMetadata = FormatMetadata(userName, serverPort, clusterID)

proxyCommand, err := optsWithMetadata.ToProxyCommand()
if err != nil {
return fmt.Errorf("failed to generate ProxyCommand: %w", err)
}

hostConfig := sshconfig.GenerateHostConfig(hostName, userName, keyPath, proxyCommand)

_, err = sshconfig.CreateOrUpdateHostConfig(ctx, hostName, hostConfig, true)
if err != nil {
return err
}

cmdio.LogString(ctx, fmt.Sprintf("Updated SSH config entry for '%s'", hostName))
return nil
}

// getServerMetadata retrieves the server metadata from the workspace and validates it via Driver Proxy.
// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless).
// For dedicated clusters, clusterID should be the same as sessionID.
Expand All @@ -265,7 +346,7 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient,
if err != nil {
return 0, "", "", errors.Join(errServerMetadata, err)
}
cmdio.LogString(ctx, "Workspace metadata: "+fmt.Sprintf("%+v", wsMetadata))
log.Debugf(ctx, "Workspace metadata: %+v", wsMetadata)

// For serverless mode, the cluster ID comes from the metadata
effectiveClusterID := clusterID
Expand Down Expand Up @@ -352,11 +433,6 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,

cmdio.LogString(ctx, "Submitting a job to start the ssh server...")

// Use manual HTTP call when hardware_accelerator is needed (SDK doesn't support it yet)
if opts.Accelerator != "" {
return submitSSHTunnelJobManual(ctx, client, jobNotebookPath, baseParams, opts)
}

task := jobs.SubmitTask{
TaskKey: sshServerTaskKey,
NotebookTask: &jobs.NotebookTask{
Expand All @@ -368,6 +444,12 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,

if opts.IsServerlessMode() {
task.EnvironmentKey = serverlessEnvironmentKey
if opts.Accelerator != "" {
cmdio.LogString(ctx, "Using accelerator: "+opts.Accelerator)
task.Compute = &jobs.Compute{
HardwareAccelerator: compute.HardwareAcceleratorType(opts.Accelerator),
}
}
} else {
task.ExistingClusterId = opts.ClusterID
}
Expand Down Expand Up @@ -399,97 +481,6 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient,
return waitForJobToStart(ctx, client, waiter.RunId, opts.TaskStartupTimeout)
}

// submitSSHTunnelJobManual submits a job using manual HTTP call for features not yet supported by the SDK.
// Currently used for hardware_accelerator field which is not yet in the SDK.
func submitSSHTunnelJobManual(ctx context.Context, client *databricks.WorkspaceClient, jobNotebookPath string, baseParams map[string]string, opts ClientOptions) error {
sessionID := opts.SessionIdentifier()
sshTunnelJobName := "ssh-server-bootstrap-" + sessionID

// Construct the request payload manually to allow custom parameters
task := map[string]any{
"task_key": sshServerTaskKey,
"notebook_task": map[string]any{
"notebook_path": jobNotebookPath,
"base_parameters": baseParams,
},
"timeout_seconds": int(opts.ServerTimeout.Seconds()),
}

if opts.IsServerlessMode() {
task["environment_key"] = serverlessEnvironmentKey
if opts.Accelerator != "" {
cmdio.LogString(ctx, "Using accelerator: "+opts.Accelerator)
task["compute"] = map[string]any{
"hardware_accelerator": opts.Accelerator,
}
}
} else {
task["existing_cluster_id"] = opts.ClusterID
}

submitRequest := map[string]any{
"run_name": sshTunnelJobName,
"timeout_seconds": int(opts.ServerTimeout.Seconds()),
"tasks": []map[string]any{task},
}

if opts.IsServerlessMode() {
submitRequest["environments"] = []map[string]any{
{
"environment_key": serverlessEnvironmentKey,
"spec": map[string]any{
"environment_version": "3",
},
},
}
}

requestBody, err := json.Marshal(submitRequest)
if err != nil {
return fmt.Errorf("failed to marshal request body: %w", err)
}

cmdio.LogString(ctx, "Request body: "+string(requestBody))

apiURL := client.Config.Host + "/api/2.1/jobs/runs/submit"
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(requestBody))
if err != nil {
return fmt.Errorf("failed to create request: %w", err)
}

req.Header.Set("Content-Type", "application/json")
if err := client.Config.Authenticate(req); err != nil {
return fmt.Errorf("failed to authenticate request: %w", err)
}

resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("failed to submit job: %w", err)
}
defer resp.Body.Close()

responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response body: %w", err)
}

if resp.StatusCode != http.StatusOK {
return fmt.Errorf("failed to submit job, status code %d: %s", resp.StatusCode, string(responseBody))
}

var result struct {
RunID int64 `json:"run_id"`
}
if err := json.Unmarshal(responseBody, &result); err != nil {
return fmt.Errorf("failed to parse response: %w", err)
}

cmdio.LogString(ctx, fmt.Sprintf("Job submitted successfully with run ID: %d", result.RunID))

// For manual submissions we still need to poll manually
return waitForJobToStart(ctx, client, result.RunID, opts.TaskStartupTimeout)
}

func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, serverPort int, clusterID string, opts ClientOptions) error {
// Create a copy with metadata for the ProxyCommand
optsWithMetadata := opts
Expand All @@ -516,8 +507,7 @@ func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, server
sshArgs = append(sshArgs, hostName)
sshArgs = append(sshArgs, opts.AdditionalArgs...)

cmdio.LogString(ctx, "Launching SSH client: ssh "+strings.Join(sshArgs, " "))

log.Debugf(ctx, "Launching SSH client: ssh %s", strings.Join(sshArgs, " "))
sshCmd := exec.CommandContext(ctx, "ssh", sshArgs...)

sshCmd.Stdin = os.Stdin
Expand Down
Loading