From a4d274e60fc0b9ebb2e20f3a646fc9f5d13e276d Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Mon, 19 Jan 2026 15:23:10 +0100 Subject: [PATCH 1/6] Add IDE flag and support for IDE integration for vscode and cursor --- experimental/ssh/cmd/connect.go | 11 +- experimental/ssh/cmd/setup.go | 21 +++- experimental/ssh/internal/client/client.go | 138 ++++++++++++++++++++- experimental/ssh/internal/setup/setup.go | 16 +-- 4 files changed, 163 insertions(+), 23 deletions(-) diff --git a/experimental/ssh/cmd/connect.go b/experimental/ssh/cmd/connect.go index 6c04db57e2..2d9102dbbd 100644 --- a/experimental/ssh/cmd/connect.go +++ b/experimental/ssh/cmd/connect.go @@ -32,6 +32,7 @@ For serverless compute: var connectionName string var accelerator string var proxyMode bool + var ide string var serverMetadata string var shutdownDelay time.Duration var maxClients int @@ -42,8 +43,9 @@ For serverless compute: var liteswap string cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)") - cmd.Flags().StringVar(&connectionName, "name", "", "Connection name (for serverless compute)") - cmd.Flags().StringVar(&accelerator, "accelerator", "", "GPU accelerator type for serverless compute (GPU_1xA10 or GPU_8xH100)") + cmd.Flags().StringVar(&connectionName, "name", "", "Connection name") + cmd.Flags().StringVar(&accelerator, "accelerator", "", "GPU accelerator type (GPU_1xA10 or GPU_8xH100)") + cmd.Flags().StringVar(&ide, "ide", "", "Open remote IDE window (vscode or cursor)") cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects") cmd.Flags().IntVar(&maxClients, "max-clients", defaultMaxClients, "Maximum number of SSH clients") cmd.Flags().BoolVar(&autoStartCluster, "auto-start-cluster", true, "Automatically start the cluster if it is not running") @@ -80,7 +82,7 @@ For serverless compute: wsClient := cmdctx.WorkspaceClient(ctx) if !proxyMode && clusterID == "" && connectionName == "" { - return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the serverless connection name") + return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)") } if accelerator != "" && connectionName == "" { @@ -89,7 +91,7 @@ For serverless compute: // Remove when we add support for serverless CPU if connectionName != "" && accelerator == "" { - return errors.New("--name flag requires --accelerator to be set (e.g. for now we only support serverless GPU compute)") + return errors.New("--name flag requires --accelerator to be set (for now we only support serverless GPU compute)") } // TODO: validate connectionName if provided @@ -100,6 +102,7 @@ For serverless compute: ConnectionName: connectionName, Accelerator: accelerator, ProxyMode: proxyMode, + IDE: ide, ServerMetadata: serverMetadata, ShutdownDelay: shutdownDelay, MaxClients: maxClients, diff --git a/experimental/ssh/cmd/setup.go b/experimental/ssh/cmd/setup.go index 3e4523904c..81b7863666 100644 --- a/experimental/ssh/cmd/setup.go +++ b/experimental/ssh/cmd/setup.go @@ -1,9 +1,11 @@ package ssh import ( + "fmt" "time" "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/experimental/ssh/internal/client" "github.com/databricks/cli/experimental/ssh/internal/setup" "github.com/databricks/cli/libs/cmdctx" "github.com/spf13/cobra" @@ -43,16 +45,27 @@ an SSH host configuration to your SSH config file. cmd.RunE = func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - client := cmdctx.WorkspaceClient(ctx) - opts := setup.SetupOptions{ + wsClient := cmdctx.WorkspaceClient(ctx) + setupOpts := setup.SetupOptions{ HostName: hostName, ClusterID: clusterID, AutoStartCluster: autoStartCluster, SSHConfigPath: sshConfigPath, ShutdownDelay: shutdownDelay, - Profile: client.Config.Profile, + Profile: wsClient.Config.Profile, } - return setup.Setup(ctx, client, opts) + clientOpts := client.ClientOptions{ + ClusterID: setupOpts.ClusterID, + AutoStartCluster: setupOpts.AutoStartCluster, + ShutdownDelay: setupOpts.ShutdownDelay, + Profile: setupOpts.Profile, + } + proxyCommand, err := clientOpts.ToProxyCommand() + if err != nil { + return fmt.Errorf("failed to generate ProxyCommand: %w", err) + } + setupOpts.ProxyCommand = proxyCommand + return setup.Setup(ctx, wsClient, setupOpts) } return cmd diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 839705c4ec..23572bf323 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -14,6 +14,7 @@ import ( "os/exec" "os/signal" "path/filepath" + "regexp" "strconv" "strings" "syscall" @@ -58,6 +59,8 @@ type ClientOptions struct { // to the cluster and proxy all traffic through stdin/stdout. // In the non proxy mode the CLI spawns an ssh client with the ProxyCommand config. ProxyMode bool + // Open remote IDE window with a specific ssh config (empty, 'vscode', or 'cursor') + IDE string // Expected format: ",,". // If present, the CLI won't attempt to start the server. ServerMetadata string @@ -171,8 +174,7 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt } // Only check cluster state for dedicated clusters - // TODO: we can remove liteswap check when we can start serverless GPU clusters via API. - if !opts.IsServerlessMode() && opts.Liteswap == "" { + if !opts.IsServerlessMode() { err := checkClusterState(ctx, client, opts.ClusterID, opts.AutoStartCluster) if err != nil { return err @@ -250,12 +252,144 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt if opts.ProxyMode { return runSSHProxy(ctx, client, serverPort, clusterID, opts) + } else if opts.IDE != "" { + return runIDE(ctx, client, userName, keyPath, serverPort, clusterID, opts) } else { cmdio.LogString(ctx, fmt.Sprintf("Additional SSH arguments: %v", opts.AdditionalArgs)) return spawnSSHClient(ctx, userName, keyPath, serverPort, clusterID, opts) } } +func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error { + // Validate IDE value + if opts.IDE != "vscode" && opts.IDE != "cursor" { + return fmt.Errorf("invalid IDE value: %s, expected 'vscode' or 'cursor'", opts.IDE) + } + + // Get connection name + connectionName := opts.SessionIdentifier() + if connectionName == "" { + return errors.New("connection name is required for IDE integration") + } + + // Get Databricks user name for the workspace path + currentUser, err := client.CurrentUser.Me(ctx) + if err != nil { + return fmt.Errorf("failed to get current user: %w", err) + } + databricksUserName := currentUser.UserName + + // Ensure SSH config entry exists + configPath, err := getSSHConfigPath() + if err != nil { + return fmt.Errorf("failed to get SSH config path: %w", err) + } + + err = ensureSSHConfigEntry(ctx, configPath, connectionName, userName, keyPath, serverPort, clusterID, opts) + if err != nil { + return fmt.Errorf("failed to ensure SSH config entry: %w", err) + } + + // Determine the IDE command + ideCommand := "code" + if opts.IDE == "cursor" { + ideCommand = "cursor" + } + + // Construct the remote SSH URI + // Format: ssh-remote+@ /Workspace/Users// + remoteURI := fmt.Sprintf("ssh-remote+%s@%s", userName, connectionName) + remotePath := fmt.Sprintf("/Workspace/Users/%s/", databricksUserName) + + cmdio.LogString(ctx, fmt.Sprintf("Launching %s with remote URI: %s and path: %s", opts.IDE, remoteURI, remotePath)) + + // Launch the IDE + ideCmd := exec.CommandContext(ctx, ideCommand, "--remote", remoteURI, remotePath) + ideCmd.Stdout = os.Stdout + ideCmd.Stderr = os.Stderr + + return ideCmd.Run() +} + +func getSSHConfigPath() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get home directory: %w", err) + } + return filepath.Join(homeDir, ".ssh", "config"), nil +} + +func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error { + // Ensure SSH directory and config file exist + sshDir := filepath.Dir(configPath) + err := os.MkdirAll(sshDir, 0o700) + if err != nil { + return fmt.Errorf("failed to create SSH directory: %w", err) + } + + _, err = os.Stat(configPath) + if os.IsNotExist(err) { + err = os.WriteFile(configPath, []byte(""), 0o600) + if err != nil { + return fmt.Errorf("failed to create SSH config file: %w", err) + } + } else if err != nil { + return fmt.Errorf("failed to check SSH config file: %w", err) + } + + // Check if the host entry already exists + existingContent, err := os.ReadFile(configPath) + if err != nil { + return fmt.Errorf("failed to read SSH config file: %w", err) + } + + hostPattern := fmt.Sprintf(`(?m)^\s*Host\s+%s\s*$`, regexp.QuoteMeta(hostName)) + matched, err := regexp.Match(hostPattern, existingContent) + if err != nil { + return fmt.Errorf("failed to check for existing host: %w", err) + } + + if matched { + cmdio.LogString(ctx, fmt.Sprintf("SSH config entry for '%s' already exists", hostName)) + return nil + } + + // Generate ProxyCommand with server metadata + optsWithMetadata := opts + optsWithMetadata.ServerMetadata = FormatMetadata(userName, serverPort, clusterID) + + proxyCommand, err := optsWithMetadata.ToProxyCommand() + if err != nil { + return fmt.Errorf("failed to generate ProxyCommand: %w", err) + } + + // Generate host config + hostConfig := fmt.Sprintf(` +Host %s + User %s + ConnectTimeout 360 + StrictHostKeyChecking accept-new + IdentitiesOnly yes + IdentityFile %q + ProxyCommand %s +`, hostName, userName, keyPath, proxyCommand) + + // Append to config file + content := string(existingContent) + if !strings.HasSuffix(content, "\n") && content != "" { + content += "\n" + } + content += hostConfig + + err = os.WriteFile(configPath, []byte(content), 0o600) + if err != nil { + return fmt.Errorf("failed to update SSH config file: %w", err) + } + + cmdio.LogString(ctx, fmt.Sprintf("Added SSH config entry for '%s'", hostName)) + return nil +} + // getServerMetadata retrieves the server metadata from the workspace and validates it via Driver Proxy. // sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). // For dedicated clusters, clusterID should be the same as sessionID. diff --git a/experimental/ssh/internal/setup/setup.go b/experimental/ssh/internal/setup/setup.go index adfe204427..0d76071a65 100644 --- a/experimental/ssh/internal/setup/setup.go +++ b/experimental/ssh/internal/setup/setup.go @@ -10,7 +10,6 @@ import ( "strings" "time" - "github.com/databricks/cli/experimental/ssh/internal/client" "github.com/databricks/cli/experimental/ssh/internal/keys" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/databricks-sdk-go" @@ -32,6 +31,8 @@ type SetupOptions struct { SSHKeysDir string // Optional auth profile name. If present, will be added as --profile flag to the ProxyCommand Profile string + // Proxy command to use for the SSH connection + ProxyCommand string } func validateClusterAccess(ctx context.Context, client *databricks.WorkspaceClient, clusterID string) error { @@ -62,17 +63,6 @@ func generateHostConfig(opts SetupOptions) (string, error) { return "", fmt.Errorf("failed to get local keys folder: %w", err) } - clientOpts := client.ClientOptions{ - ClusterID: opts.ClusterID, - AutoStartCluster: opts.AutoStartCluster, - ShutdownDelay: opts.ShutdownDelay, - Profile: opts.Profile, - } - proxyCommand, err := clientOpts.ToProxyCommand() - if err != nil { - return "", fmt.Errorf("failed to generate ProxyCommand: %w", err) - } - hostConfig := fmt.Sprintf(` Host %s User root @@ -81,7 +71,7 @@ Host %s IdentitiesOnly yes IdentityFile %q ProxyCommand %s -`, opts.HostName, identityFilePath, proxyCommand) +`, opts.HostName, identityFilePath, opts.ProxyCommand) return hostConfig, nil } From a0009d8951b8bd5c70e00178f503a232363aeb76 Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Thu, 5 Feb 2026 17:16:33 +0100 Subject: [PATCH 2/6] Mark hidden flags for serverless compute, remove manual Http call --- experimental/ssh/cmd/connect.go | 10 +- experimental/ssh/internal/client/client.go | 104 ++------------------- 2 files changed, 13 insertions(+), 101 deletions(-) diff --git a/experimental/ssh/cmd/connect.go b/experimental/ssh/cmd/connect.go index 2d9102dbbd..87876d85ef 100644 --- a/experimental/ssh/cmd/connect.go +++ b/experimental/ssh/cmd/connect.go @@ -43,13 +43,17 @@ For serverless compute: var liteswap string cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)") - cmd.Flags().StringVar(&connectionName, "name", "", "Connection name") - cmd.Flags().StringVar(&accelerator, "accelerator", "", "GPU accelerator type (GPU_1xA10 or GPU_8xH100)") - cmd.Flags().StringVar(&ide, "ide", "", "Open remote IDE window (vscode or cursor)") cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects") cmd.Flags().IntVar(&maxClients, "max-clients", defaultMaxClients, "Maximum number of SSH clients") cmd.Flags().BoolVar(&autoStartCluster, "auto-start-cluster", true, "Automatically start the cluster if it is not running") + cmd.Flags().StringVar(&connectionName, "name", "", "Connection name (for serverless compute)") + cmd.Flags().MarkHidden("name") + cmd.Flags().StringVar(&accelerator, "accelerator", "", "GPU accelerator type (GPU_1xA10 or GPU_8xH100)") + cmd.Flags().MarkHidden("accelerator") + cmd.Flags().StringVar(&ide, "ide", "", "Open remote IDE window (vscode or cursor)") + cmd.Flags().MarkHidden("ide") + cmd.Flags().BoolVar(&proxyMode, "proxy", false, "ProxyCommand mode") cmd.Flags().MarkHidden("proxy") cmd.Flags().StringVar(&serverMetadata, "metadata", "", "Metadata of the running SSH server (format: ,)") diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 23572bf323..2fcc08f373 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -1,11 +1,9 @@ package client import ( - "bytes" "context" _ "embed" "encoding/base64" - "encoding/json" "errors" "fmt" "io" @@ -486,11 +484,6 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, cmdio.LogString(ctx, "Submitting a job to start the ssh server...") - // Use manual HTTP call when hardware_accelerator is needed (SDK doesn't support it yet) - if opts.Accelerator != "" { - return submitSSHTunnelJobManual(ctx, client, jobNotebookPath, baseParams, opts) - } - task := jobs.SubmitTask{ TaskKey: sshServerTaskKey, NotebookTask: &jobs.NotebookTask{ @@ -502,6 +495,12 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, if opts.IsServerlessMode() { task.EnvironmentKey = serverlessEnvironmentKey + if opts.Accelerator != "" { + cmdio.LogString(ctx, "Using accelerator: "+opts.Accelerator) + task.Compute = &jobs.Compute{ + HardwareAccelerator: compute.HardwareAcceleratorType(opts.Accelerator), + } + } } else { task.ExistingClusterId = opts.ClusterID } @@ -533,97 +532,6 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, return waitForJobToStart(ctx, client, waiter.RunId, opts.TaskStartupTimeout) } -// submitSSHTunnelJobManual submits a job using manual HTTP call for features not yet supported by the SDK. -// Currently used for hardware_accelerator field which is not yet in the SDK. -func submitSSHTunnelJobManual(ctx context.Context, client *databricks.WorkspaceClient, jobNotebookPath string, baseParams map[string]string, opts ClientOptions) error { - sessionID := opts.SessionIdentifier() - sshTunnelJobName := "ssh-server-bootstrap-" + sessionID - - // Construct the request payload manually to allow custom parameters - task := map[string]any{ - "task_key": sshServerTaskKey, - "notebook_task": map[string]any{ - "notebook_path": jobNotebookPath, - "base_parameters": baseParams, - }, - "timeout_seconds": int(opts.ServerTimeout.Seconds()), - } - - if opts.IsServerlessMode() { - task["environment_key"] = serverlessEnvironmentKey - if opts.Accelerator != "" { - cmdio.LogString(ctx, "Using accelerator: "+opts.Accelerator) - task["compute"] = map[string]any{ - "hardware_accelerator": opts.Accelerator, - } - } - } else { - task["existing_cluster_id"] = opts.ClusterID - } - - submitRequest := map[string]any{ - "run_name": sshTunnelJobName, - "timeout_seconds": int(opts.ServerTimeout.Seconds()), - "tasks": []map[string]any{task}, - } - - if opts.IsServerlessMode() { - submitRequest["environments"] = []map[string]any{ - { - "environment_key": serverlessEnvironmentKey, - "spec": map[string]any{ - "environment_version": "3", - }, - }, - } - } - - requestBody, err := json.Marshal(submitRequest) - if err != nil { - return fmt.Errorf("failed to marshal request body: %w", err) - } - - cmdio.LogString(ctx, "Request body: "+string(requestBody)) - - apiURL := client.Config.Host + "/api/2.1/jobs/runs/submit" - req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(requestBody)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - if err := client.Config.Authenticate(req); err != nil { - return fmt.Errorf("failed to authenticate request: %w", err) - } - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return fmt.Errorf("failed to submit job: %w", err) - } - defer resp.Body.Close() - - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read response body: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("failed to submit job, status code %d: %s", resp.StatusCode, string(responseBody)) - } - - var result struct { - RunID int64 `json:"run_id"` - } - if err := json.Unmarshal(responseBody, &result); err != nil { - return fmt.Errorf("failed to parse response: %w", err) - } - - cmdio.LogString(ctx, fmt.Sprintf("Job submitted successfully with run ID: %d", result.RunID)) - - // For manual submissions we still need to poll manually - return waitForJobToStart(ctx, client, result.RunID, opts.TaskStartupTimeout) -} - func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, serverPort int, clusterID string, opts ClientOptions) error { // Create a copy with metadata for the ProxyCommand optsWithMetadata := opts From 51b8b2686ed9e38823a69a56170db3e07e6a6636 Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Fri, 6 Feb 2026 14:38:58 +0100 Subject: [PATCH 3/6] Separate IDE options into constants --- experimental/ssh/internal/client/client.go | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 2fcc08f373..6c4695b6ac 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -40,6 +40,11 @@ var errServerMetadata = errors.New("server metadata error") const ( sshServerTaskKey = "start_ssh_server" serverlessEnvironmentKey = "ssh_tunnel_serverless" + + VSCodeOption = "vscode" + VSCodeCommand = "code" + CursorOption = "cursor" + CursorCommand = "cursor" ) type ClientOptions struct { @@ -259,12 +264,10 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt } func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error { - // Validate IDE value - if opts.IDE != "vscode" && opts.IDE != "cursor" { - return fmt.Errorf("invalid IDE value: %s, expected 'vscode' or 'cursor'", opts.IDE) + if opts.IDE != VSCodeOption && opts.IDE != CursorOption { + return fmt.Errorf("invalid IDE value: %s, expected '%s' or '%s'", opts.IDE, VSCodeOption, CursorOption) } - // Get connection name connectionName := opts.SessionIdentifier() if connectionName == "" { return errors.New("connection name is required for IDE integration") @@ -288,10 +291,9 @@ func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, k return fmt.Errorf("failed to ensure SSH config entry: %w", err) } - // Determine the IDE command - ideCommand := "code" - if opts.IDE == "cursor" { - ideCommand = "cursor" + ideCommand := VSCodeCommand + if opts.IDE == CursorOption { + ideCommand = CursorCommand } // Construct the remote SSH URI @@ -301,7 +303,6 @@ func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, k cmdio.LogString(ctx, fmt.Sprintf("Launching %s with remote URI: %s and path: %s", opts.IDE, remoteURI, remotePath)) - // Launch the IDE ideCmd := exec.CommandContext(ctx, ideCommand, "--remote", remoteURI, remotePath) ideCmd.Stdout = os.Stdout ideCmd.Stderr = os.Stderr From abb3157927b16003b6a5458ebc8c568f6419c0aa Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Tue, 10 Feb 2026 13:15:48 +0100 Subject: [PATCH 4/6] Move unnecessary logging to debug level --- experimental/ssh/internal/client/client.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 6c4695b6ac..6b3e6b6382 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -398,7 +398,7 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, if err != nil { return 0, "", "", errors.Join(errServerMetadata, err) } - cmdio.LogString(ctx, "Workspace metadata: "+fmt.Sprintf("%+v", wsMetadata)) + log.Debugf(ctx, "Workspace metadata: %+v", wsMetadata) // For serverless mode, the cluster ID comes from the metadata effectiveClusterID := clusterID @@ -559,8 +559,7 @@ func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, server sshArgs = append(sshArgs, hostName) sshArgs = append(sshArgs, opts.AdditionalArgs...) - cmdio.LogString(ctx, "Launching SSH client: ssh "+strings.Join(sshArgs, " ")) - + log.Debugf(ctx, "Launching SSH client: ssh %s", strings.Join(sshArgs, " ")) sshCmd := exec.CommandContext(ctx, "ssh", sshArgs...) sshCmd.Stdin = os.Stdin From 627868cd771f031e491bce90f6166ecfeb94d187 Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Thu, 12 Feb 2026 15:17:00 +0100 Subject: [PATCH 5/6] Remove dedicated cluster and serverless compute specific documentation from connect command help message --- experimental/ssh/cmd/connect.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/experimental/ssh/cmd/connect.go b/experimental/ssh/cmd/connect.go index 87876d85ef..4eca1aee7b 100644 --- a/experimental/ssh/cmd/connect.go +++ b/experimental/ssh/cmd/connect.go @@ -19,12 +19,6 @@ func newConnectCommand() *cobra.Command { This command establishes an SSH connection to Databricks compute, setting up the SSH server and handling the connection proxy. -For dedicated clusters: - databricks ssh connect --cluster= - -For serverless compute: - databricks ssh connect --name= [--accelerator=] - ` + disclaimer, } From ec55ba2b979a3a197838ccd96dae2f09b0e19f53 Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Mon, 16 Feb 2026 11:11:05 +0100 Subject: [PATCH 6/6] Simplify ssh config management (#4453) ## Changes Simplify sshd config management, so we can easily re-use it. The main change is that we add one "include" directive to the system ssh config, and all new host configs are encapsulated in single files, which are easy to replace or add. Here we also solve a separate logic of proper escaping env vars, as before we were letting new lines sneak into SetEnv directive, which is not allowed Based on https://github.com/databricks/cli/pull/4452 ## Why ## Tests --- experimental/ssh/internal/client/client.go | 70 +--- experimental/ssh/internal/keys/secrets.go | 15 +- experimental/ssh/internal/server/sshd.go | 18 +- experimental/ssh/internal/server/sshd_test.go | 73 +++++ experimental/ssh/internal/setup/setup.go | 121 ++----- experimental/ssh/internal/setup/setup_test.go | 301 +++++------------- .../ssh/internal/sshconfig/sshconfig.go | 172 ++++++++++ .../ssh/internal/sshconfig/sshconfig_test.go | 223 +++++++++++++ 8 files changed, 597 insertions(+), 396 deletions(-) create mode 100644 experimental/ssh/internal/server/sshd_test.go create mode 100644 experimental/ssh/internal/sshconfig/sshconfig.go create mode 100644 experimental/ssh/internal/sshconfig/sshconfig_test.go diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 6b3e6b6382..940f792f0e 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -12,7 +12,6 @@ import ( "os/exec" "os/signal" "path/filepath" - "regexp" "strconv" "strings" "syscall" @@ -20,6 +19,7 @@ import ( "github.com/databricks/cli/experimental/ssh/internal/keys" "github.com/databricks/cli/experimental/ssh/internal/proxy" + "github.com/databricks/cli/experimental/ssh/internal/sshconfig" sshWorkspace "github.com/databricks/cli/experimental/ssh/internal/workspace" "github.com/databricks/cli/internal/build" "github.com/databricks/cli/libs/cmdio" @@ -281,7 +281,7 @@ func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, k databricksUserName := currentUser.UserName // Ensure SSH config entry exists - configPath, err := getSSHConfigPath() + configPath, err := sshconfig.GetMainConfigPath() if err != nil { return fmt.Errorf("failed to get SSH config path: %w", err) } @@ -310,47 +310,11 @@ func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, k return ideCmd.Run() } -func getSSHConfigPath() (string, error) { - homeDir, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("failed to get home directory: %w", err) - } - return filepath.Join(homeDir, ".ssh", "config"), nil -} - func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error { - // Ensure SSH directory and config file exist - sshDir := filepath.Dir(configPath) - err := os.MkdirAll(sshDir, 0o700) + // Ensure the Include directive exists in the main SSH config + err := sshconfig.EnsureIncludeDirective(configPath) if err != nil { - return fmt.Errorf("failed to create SSH directory: %w", err) - } - - _, err = os.Stat(configPath) - if os.IsNotExist(err) { - err = os.WriteFile(configPath, []byte(""), 0o600) - if err != nil { - return fmt.Errorf("failed to create SSH config file: %w", err) - } - } else if err != nil { - return fmt.Errorf("failed to check SSH config file: %w", err) - } - - // Check if the host entry already exists - existingContent, err := os.ReadFile(configPath) - if err != nil { - return fmt.Errorf("failed to read SSH config file: %w", err) - } - - hostPattern := fmt.Sprintf(`(?m)^\s*Host\s+%s\s*$`, regexp.QuoteMeta(hostName)) - matched, err := regexp.Match(hostPattern, existingContent) - if err != nil { - return fmt.Errorf("failed to check for existing host: %w", err) - } - - if matched { - cmdio.LogString(ctx, fmt.Sprintf("SSH config entry for '%s' already exists", hostName)) - return nil + return err } // Generate ProxyCommand with server metadata @@ -362,30 +326,14 @@ func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, k return fmt.Errorf("failed to generate ProxyCommand: %w", err) } - // Generate host config - hostConfig := fmt.Sprintf(` -Host %s - User %s - ConnectTimeout 360 - StrictHostKeyChecking accept-new - IdentitiesOnly yes - IdentityFile %q - ProxyCommand %s -`, hostName, userName, keyPath, proxyCommand) - - // Append to config file - content := string(existingContent) - if !strings.HasSuffix(content, "\n") && content != "" { - content += "\n" - } - content += hostConfig + hostConfig := sshconfig.GenerateHostConfig(hostName, userName, keyPath, proxyCommand) - err = os.WriteFile(configPath, []byte(content), 0o600) + _, err = sshconfig.CreateOrUpdateHostConfig(ctx, hostName, hostConfig, true) if err != nil { - return fmt.Errorf("failed to update SSH config file: %w", err) + return err } - cmdio.LogString(ctx, fmt.Sprintf("Added SSH config entry for '%s'", hostName)) + cmdio.LogString(ctx, fmt.Sprintf("Updated SSH config entry for '%s'", hostName)) return nil } diff --git a/experimental/ssh/internal/keys/secrets.go b/experimental/ssh/internal/keys/secrets.go index eac692f235..d4e00d10ba 100644 --- a/experimental/ssh/internal/keys/secrets.go +++ b/experimental/ssh/internal/keys/secrets.go @@ -18,10 +18,23 @@ func CreateKeysSecretScope(ctx context.Context, client *databricks.WorkspaceClie return "", fmt.Errorf("failed to get current user: %w", err) } secretScopeName := fmt.Sprintf("%s-%s-ssh-tunnel-keys", me.UserName, sessionID) + + // Do not create the scope if it already exists. + // We can instead filter out "resource already exists" errors from CreateScope, + // but that API can also lead to "limit exceeded" errors, even if the scope does actually exist. + scope, err := client.Secrets.ListSecretsByScope(ctx, secretScopeName) + if err != nil && !errors.Is(err, databricks.ErrResourceDoesNotExist) { + return "", fmt.Errorf("failed to check if secret scope %s exists: %w", secretScopeName, err) + } + + if scope != nil && err == nil { + return secretScopeName, nil + } + err = client.Secrets.CreateScope(ctx, workspace.CreateScope{ Scope: secretScopeName, }) - if err != nil && !errors.Is(err, databricks.ErrResourceAlreadyExists) { + if err != nil { return "", fmt.Errorf("failed to create secrets scope: %w", err) } return secretScopeName, nil diff --git a/experimental/ssh/internal/server/sshd.go b/experimental/ssh/internal/server/sshd.go index 7a038e73b5..c8f23d02a5 100644 --- a/experimental/ssh/internal/server/sshd.go +++ b/experimental/ssh/internal/server/sshd.go @@ -52,15 +52,13 @@ func prepareSSHDConfig(ctx context.Context, client *databricks.WorkspaceClient, return "", err } - // Set all available env vars, wrapping values in quotes and escaping quotes inside values + // Set all available env vars, wrapping values in quotes, escaping quotes, and stripping newlines setEnv := "SetEnv" for _, env := range os.Environ() { parts := strings.SplitN(env, "=", 2) - if len(parts) != 2 { - continue + if len(parts) == 2 { + setEnv += " " + parts[0] + "=\"" + escapeEnvValue(parts[1]) + "\"" } - valEscaped := strings.ReplaceAll(parts[1], "\"", "\\\"") - setEnv += " " + parts[0] + "=\"" + valEscaped + "\"" } setEnv += " DATABRICKS_CLI_UPSTREAM=databricks_ssh_tunnel" setEnv += " DATABRICKS_CLI_UPSTREAM_VERSION=" + opts.Version @@ -94,3 +92,13 @@ func prepareSSHDConfig(ctx context.Context, client *databricks.WorkspaceClient, func createSSHDProcess(ctx context.Context, configPath string) *exec.Cmd { return exec.CommandContext(ctx, "/usr/sbin/sshd", "-f", configPath, "-i") } + +// escapeEnvValue escapes a value for use in sshd SetEnv directive. +// It strips newlines and escapes backslashes and quotes. +func escapeEnvValue(val string) string { + val = strings.ReplaceAll(val, "\r", "") + val = strings.ReplaceAll(val, "\n", "") + val = strings.ReplaceAll(val, "\\", "\\\\") + val = strings.ReplaceAll(val, "\"", "\\\"") + return val +} diff --git a/experimental/ssh/internal/server/sshd_test.go b/experimental/ssh/internal/server/sshd_test.go new file mode 100644 index 0000000000..a453d987a0 --- /dev/null +++ b/experimental/ssh/internal/server/sshd_test.go @@ -0,0 +1,73 @@ +package server + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEscapeEnvValue(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "simple value", + input: "hello", + expected: "hello", + }, + { + name: "value with quotes", + input: `say "hello"`, + expected: `say \"hello\"`, + }, + { + name: "value with newline", + input: "line1\nline2", + expected: "line1line2", + }, + { + name: "value with carriage return", + input: "line1\rline2", + expected: "line1line2", + }, + { + name: "value with CRLF", + input: "line1\r\nline2", + expected: "line1line2", + }, + { + name: "value with quotes and newlines", + input: "say \"hello\"\nworld", + expected: `say \"hello\"world`, + }, + { + name: "empty value", + input: "", + expected: "", + }, + { + name: "only newlines", + input: "\n\r\n", + expected: "", + }, + { + name: "backslashes", + input: `foo\bar\`, + expected: `foo\\bar\\`, + }, + { + name: "backslash before quote", + input: `foo\"bar`, + expected: `foo\\\"bar`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := escapeEnvValue(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/experimental/ssh/internal/setup/setup.go b/experimental/ssh/internal/setup/setup.go index 0d76071a65..99b5a68902 100644 --- a/experimental/ssh/internal/setup/setup.go +++ b/experimental/ssh/internal/setup/setup.go @@ -4,13 +4,10 @@ import ( "context" "errors" "fmt" - "os" - "path/filepath" - "regexp" - "strings" "time" "github.com/databricks/cli/experimental/ssh/internal/keys" + "github.com/databricks/cli/experimental/ssh/internal/sshconfig" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/service/compute" @@ -46,97 +43,16 @@ func validateClusterAccess(ctx context.Context, client *databricks.WorkspaceClie return nil } -func resolveConfigPath(configPath string) (string, error) { - if configPath != "" { - return configPath, nil - } - homeDir, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("failed to get home directory: %w", err) - } - return filepath.Join(homeDir, ".ssh", "config"), nil -} - func generateHostConfig(opts SetupOptions) (string, error) { identityFilePath, err := keys.GetLocalSSHKeyPath(opts.ClusterID, opts.SSHKeysDir) if err != nil { return "", fmt.Errorf("failed to get local keys folder: %w", err) } - hostConfig := fmt.Sprintf(` -Host %s - User root - ConnectTimeout 360 - StrictHostKeyChecking accept-new - IdentitiesOnly yes - IdentityFile %q - ProxyCommand %s -`, opts.HostName, identityFilePath, opts.ProxyCommand) - + hostConfig := sshconfig.GenerateHostConfig(opts.HostName, "root", identityFilePath, opts.ProxyCommand) return hostConfig, nil } -func ensureSSHConfigExists(configPath string) error { - _, err := os.Stat(configPath) - if os.IsNotExist(err) { - sshDir := filepath.Dir(configPath) - err = os.MkdirAll(sshDir, 0o700) - if err != nil { - return fmt.Errorf("failed to create SSH directory: %w", err) - } - err = os.WriteFile(configPath, []byte(""), 0o600) - if err != nil { - return fmt.Errorf("failed to create SSH config file: %w", err) - } - return nil - } else if err != nil { - return fmt.Errorf("failed to check SSH config file: %w", err) - } - return nil -} - -func checkExistingHosts(content []byte, hostName string) (bool, error) { - existingContent := string(content) - pattern := fmt.Sprintf(`(?m)^\s*Host\s+%s\s*$`, regexp.QuoteMeta(hostName)) - matched, err := regexp.MatchString(pattern, existingContent) - if err != nil { - return false, fmt.Errorf("failed to check for existing host: %w", err) - } - if matched { - return true, nil - } - return false, nil -} - -func createBackup(content []byte, configPath string) (string, error) { - backupPath := configPath + ".bak" - err := os.WriteFile(backupPath, content, 0o600) - if err != nil { - return backupPath, fmt.Errorf("failed to create backup of SSH config file: %w", err) - } - return backupPath, nil -} - -func updateSSHConfigFile(configPath, hostConfig, hostName string) error { - content, err := os.ReadFile(configPath) - if err != nil { - return fmt.Errorf("failed to read SSH config file: %w", err) - } - - existingContent := string(content) - if !strings.HasSuffix(existingContent, "\n") && existingContent != "" { - existingContent += "\n" - } - newContent := existingContent + hostConfig - - err = os.WriteFile(configPath, []byte(newContent), 0o600) - if err != nil { - return fmt.Errorf("failed to update SSH config file: %w", err) - } - - return nil -} - func clusterSelectionPrompt(ctx context.Context, client *databricks.WorkspaceClient) (string, error) { sp := cmdio.NewSpinner(ctx) sp.Update("Loading clusters.") @@ -174,50 +90,51 @@ func Setup(ctx context.Context, client *databricks.WorkspaceClient, opts SetupOp return err } - configPath, err := resolveConfigPath(opts.SSHConfigPath) + configPath, err := sshconfig.GetMainConfigPathOrDefault(opts.SSHConfigPath) if err != nil { return err } - hostConfig, err := generateHostConfig(opts) + err = sshconfig.EnsureIncludeDirective(configPath) if err != nil { return err } - err = ensureSSHConfigExists(configPath) + hostConfig, err := generateHostConfig(opts) if err != nil { return err } - existingContent, err := os.ReadFile(configPath) + exists, err := sshconfig.HostConfigExists(opts.HostName) if err != nil { - return fmt.Errorf("failed to read SSH config file: %w", err) + return err } - if len(existingContent) > 0 { - exists, err := checkExistingHosts(existingContent, opts.HostName) + recreate := false + if exists { + recreate, err = sshconfig.PromptRecreateConfig(ctx, opts.HostName) if err != nil { return err } - if exists { - cmdio.LogString(ctx, fmt.Sprintf("Host '%s' already exists in the SSH config, skipping setup", opts.HostName)) + if !recreate { + cmdio.LogString(ctx, fmt.Sprintf("Skipping setup for host '%s'", opts.HostName)) return nil } - backupPath, err := createBackup(existingContent, configPath) - if err != nil { - return err - } - cmdio.LogString(ctx, "Created backup of existing SSH config at "+backupPath) } cmdio.LogString(ctx, "Adding new entry to the SSH config:\n"+hostConfig) - err = updateSSHConfigFile(configPath, hostConfig, opts.HostName) + _, err = sshconfig.CreateOrUpdateHostConfig(ctx, opts.HostName, hostConfig, recreate) + if err != nil { + return err + } + + hostConfigPath, err := sshconfig.GetHostConfigPath(opts.HostName) if err != nil { return err } - cmdio.LogString(ctx, fmt.Sprintf("Updated SSH config file at %s with '%s' host", configPath, opts.HostName)) + cmdio.LogString(ctx, fmt.Sprintf("Created SSH config file at %s for '%s' host", hostConfigPath, opts.HostName)) cmdio.LogString(ctx, fmt.Sprintf("You can now connect to the cluster using 'ssh %s' terminal command, or use remote capabilities of your IDE", opts.HostName)) return nil } diff --git a/experimental/ssh/internal/setup/setup_test.go b/experimental/ssh/internal/setup/setup_test.go index aa803dfe1c..975828a3c8 100644 --- a/experimental/ssh/internal/setup/setup_test.go +++ b/experimental/ssh/internal/setup/setup_test.go @@ -118,15 +118,24 @@ func TestGenerateProxyCommand_ServerlessModeWithAccelerator(t *testing.T) { } func TestGenerateHostConfig_Valid(t *testing.T) { - // Create a temporary directory for testing tmpDir := t.TempDir() + clientOpts := client.ClientOptions{ + ClusterID: "cluster-123", + AutoStartCluster: true, + ShutdownDelay: 30 * time.Second, + Profile: "test-profile", + } + proxyCommand, err := clientOpts.ToProxyCommand() + require.NoError(t, err) + opts := SetupOptions{ HostName: "test-host", ClusterID: "cluster-123", SSHKeysDir: tmpDir, ShutdownDelay: 30 * time.Second, Profile: "test-profile", + ProxyCommand: proxyCommand, } result, err := generateHostConfig(opts) @@ -139,29 +148,35 @@ func TestGenerateHostConfig_Valid(t *testing.T) { assert.Contains(t, result, "--shutdown-delay=30s") assert.Contains(t, result, "--profile=test-profile") - // Check that identity file path is included expectedKeyPath := filepath.Join(tmpDir, "cluster-123") assert.Contains(t, result, fmt.Sprintf(`IdentityFile %q`, expectedKeyPath)) } func TestGenerateHostConfig_WithoutProfile(t *testing.T) { - // Create a temporary directory for testing tmpDir := t.TempDir() + clientOpts := client.ClientOptions{ + ClusterID: "cluster-123", + AutoStartCluster: true, + ShutdownDelay: 30 * time.Second, + Profile: "", + } + proxyCommand, err := clientOpts.ToProxyCommand() + require.NoError(t, err) + opts := SetupOptions{ HostName: "test-host", ClusterID: "cluster-123", SSHKeysDir: tmpDir, ShutdownDelay: 30 * time.Second, - Profile: "", // No profile + Profile: "", + ProxyCommand: proxyCommand, } result, err := generateHostConfig(opts) assert.NoError(t, err) - // Should not contain profile option assert.NotContains(t, result, "--profile=") - // But should contain other elements assert.Contains(t, result, "Host test-host") assert.Contains(t, result, "--cluster=cluster-123") } @@ -187,181 +202,12 @@ func TestGenerateHostConfig_PathEscaping(t *testing.T) { assert.Contains(t, result, fmt.Sprintf(`IdentityFile %q`, expectedPath)) } -func TestEnsureSSHConfigExists(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, ".ssh", "config") - - err := ensureSSHConfigExists(configPath) - assert.NoError(t, err) - - // Check that directory was created - _, err = os.Stat(filepath.Dir(configPath)) - assert.NoError(t, err) - - // Check that file was created - _, err = os.Stat(configPath) - assert.NoError(t, err) - - // Check that file is empty - content, err := os.ReadFile(configPath) - assert.NoError(t, err) - assert.Empty(t, content) -} - -func TestCheckExistingHosts_NoExistingHost(t *testing.T) { - content := []byte(`Host other-host - User root - HostName example.com - -Host another-host - User admin -`) - exists, err := checkExistingHosts(content, "test-host") - assert.NoError(t, err) - assert.False(t, exists) -} - -func TestCheckExistingHosts_HostAlreadyExists(t *testing.T) { - content := []byte(`Host test-host - User root - HostName example.com - -Host another-host - User admin -`) - exists, err := checkExistingHosts(content, "another-host") - assert.NoError(t, err) - assert.True(t, exists) -} - -func TestCheckExistingHosts_EmptyContent(t *testing.T) { - content := []byte("") - exists, err := checkExistingHosts(content, "test-host") - assert.NoError(t, err) - assert.False(t, exists) -} - -func TestCheckExistingHosts_HostNameWithWhitespaces(t *testing.T) { - content := []byte(` Host test-host `) - exists, err := checkExistingHosts(content, "test-host") - assert.NoError(t, err) - assert.True(t, exists) -} - -func TestCheckExistingHosts_PartialNameMatch(t *testing.T) { - content := []byte(`Host test-host-long`) - exists, err := checkExistingHosts(content, "test-host") - assert.NoError(t, err) - assert.False(t, exists) -} - -func TestCreateBackup_CreatesBackupSuccessfully(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - content := []byte("original content") - - backupPath, err := createBackup(content, configPath) - assert.NoError(t, err) - assert.Equal(t, configPath+".bak", backupPath) - - // Check that backup file was created with correct content - backupContent, err := os.ReadFile(backupPath) - assert.NoError(t, err) - assert.Equal(t, content, backupContent) -} - -func TestCreateBackup_OverwritesExistingBackup(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - backupPath := configPath + ".bak" - - // Create existing backup - oldContent := []byte("old backup") - err := os.WriteFile(backupPath, oldContent, 0o644) - require.NoError(t, err) - - // Create new backup - newContent := []byte("new content") - resultPath, err := createBackup(newContent, configPath) - assert.NoError(t, err) - assert.Equal(t, backupPath, resultPath) - - // Check that backup was overwritten - backupContent, err := os.ReadFile(backupPath) - assert.NoError(t, err) - assert.Equal(t, newContent, backupContent) -} - -func TestUpdateSSHConfigFile_UpdatesSuccessfully(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - - // Create initial config file - initialContent := "# SSH Config\nHost existing\n User root\n" - err := os.WriteFile(configPath, []byte(initialContent), 0o600) - require.NoError(t, err) - - hostConfig := "\nHost new-host\n User root\n HostName example.com\n" - err = updateSSHConfigFile(configPath, hostConfig, "new-host") - assert.NoError(t, err) - - // Check that content was appended - finalContent, err := os.ReadFile(configPath) - assert.NoError(t, err) - expected := initialContent + hostConfig - assert.Equal(t, expected, string(finalContent)) -} - -func TestUpdateSSHConfigFile_AddsNewlineIfMissing(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - - // Create config file without trailing newline - initialContent := "Host existing\n User root" - err := os.WriteFile(configPath, []byte(initialContent), 0o600) - require.NoError(t, err) - - hostConfig := "\nHost new-host\n User root\n" - err = updateSSHConfigFile(configPath, hostConfig, "new-host") - assert.NoError(t, err) - - // Check that newline was added before the new content - finalContent, err := os.ReadFile(configPath) - assert.NoError(t, err) - expected := initialContent + "\n" + hostConfig - assert.Equal(t, expected, string(finalContent)) -} - -func TestUpdateSSHConfigFile_HandlesEmptyFile(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - - // Create empty config file - err := os.WriteFile(configPath, []byte(""), 0o600) - require.NoError(t, err) - - hostConfig := "Host new-host\n User root\n" - err = updateSSHConfigFile(configPath, hostConfig, "new-host") - assert.NoError(t, err) - - // Check that content was added without extra newlines - finalContent, err := os.ReadFile(configPath) - assert.NoError(t, err) - assert.Equal(t, hostConfig, string(finalContent)) -} - -func TestUpdateSSHConfigFile_HandlesReadError(t *testing.T) { - configPath := "/nonexistent/file" - hostConfig := "Host new-host\n" - - err := updateSSHConfigFile(configPath, hostConfig, "new-host") - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to read SSH config file") -} - func TestSetup_SuccessfulWithNewConfigFile(t *testing.T) { ctx := cmdio.MockDiscard(context.Background()) tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + configPath := filepath.Join(tmpDir, "ssh_config") m := mocks.NewMockWorkspaceClient(t) @@ -380,22 +226,43 @@ func TestSetup_SuccessfulWithNewConfigFile(t *testing.T) { Profile: "test-profile", } - err := Setup(ctx, m.WorkspaceClient, opts) + clientOpts := client.ClientOptions{ + ClusterID: opts.ClusterID, + AutoStartCluster: opts.AutoStartCluster, + ShutdownDelay: opts.ShutdownDelay, + Profile: opts.Profile, + } + proxyCommand, err := clientOpts.ToProxyCommand() + require.NoError(t, err) + opts.ProxyCommand = proxyCommand + + err = Setup(ctx, m.WorkspaceClient, opts) assert.NoError(t, err) - // Check that config file was created + // Check that main config has Include directive content, err := os.ReadFile(configPath) assert.NoError(t, err) - configStr := string(content) - assert.Contains(t, configStr, "Host test-host") - assert.Contains(t, configStr, "--cluster=cluster-123") - assert.Contains(t, configStr, "--profile=test-profile") + assert.Contains(t, configStr, "Include") + // SSH config uses forward slashes on all platforms + assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") + + // Check that host config file was created + hostConfigPath := filepath.Join(tmpDir, ".databricks", "ssh-tunnel-configs", "test-host") + hostContent, err := os.ReadFile(hostConfigPath) + assert.NoError(t, err) + hostConfigStr := string(hostContent) + assert.Contains(t, hostConfigStr, "Host test-host") + assert.Contains(t, hostConfigStr, "--cluster=cluster-123") + assert.Contains(t, hostConfigStr, "--profile=test-profile") } func TestSetup_SuccessfulWithExistingConfigFile(t *testing.T) { ctx := cmdio.MockDiscard(context.Background()) tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + configPath := filepath.Join(tmpDir, "ssh_config") // Create existing config file @@ -418,54 +285,34 @@ func TestSetup_SuccessfulWithExistingConfigFile(t *testing.T) { ShutdownDelay: 60 * time.Second, } - err = Setup(ctx, m.WorkspaceClient, opts) - assert.NoError(t, err) - - // Check that config file was updated and backup was created - content, err := os.ReadFile(configPath) - assert.NoError(t, err) - - configStr := string(content) - assert.Contains(t, configStr, "# Existing SSH Config") // Original content preserved - assert.Contains(t, configStr, "Host new-host") // New content added - assert.Contains(t, configStr, "--cluster=cluster-456") - - // Check backup was created - backupPath := configPath + ".bak" - backupContent, err := os.ReadFile(backupPath) - assert.NoError(t, err) - assert.Equal(t, existingContent, string(backupContent)) -} - -func TestSetup_DoesNotOverrideExistingHost(t *testing.T) { - ctx := cmdio.MockDiscard(context.Background()) - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "ssh_config") - - // Create config file with existing host - existingContent := "Host duplicate-host\n User root\n" - err := os.WriteFile(configPath, []byte(existingContent), 0o600) - require.NoError(t, err) - - m := mocks.NewMockWorkspaceClient(t) - clustersAPI := m.GetMockClustersAPI() - - clustersAPI.EXPECT().Get(ctx, compute.GetClusterRequest{ClusterId: "cluster-123"}).Return(&compute.ClusterDetails{ - DataSecurityMode: compute.DataSecurityModeSingleUser, - }, nil) - - opts := SetupOptions{ - HostName: "duplicate-host", // Same as existing - ClusterID: "cluster-123", - SSHConfigPath: configPath, - SSHKeysDir: tmpDir, - ShutdownDelay: 30 * time.Second, + clientOpts := client.ClientOptions{ + ClusterID: opts.ClusterID, + AutoStartCluster: opts.AutoStartCluster, + ShutdownDelay: opts.ShutdownDelay, + Profile: opts.Profile, } + proxyCommand, err := clientOpts.ToProxyCommand() + require.NoError(t, err) + opts.ProxyCommand = proxyCommand err = Setup(ctx, m.WorkspaceClient, opts) assert.NoError(t, err) + // Check that main config has Include directive and preserves existing content content, err := os.ReadFile(configPath) assert.NoError(t, err) - assert.Equal(t, "Host duplicate-host\n User root\n", string(content)) + configStr := string(content) + assert.Contains(t, configStr, "Include") + // SSH config uses forward slashes on all platforms + assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") + assert.Contains(t, configStr, "# Existing SSH Config") + assert.Contains(t, configStr, "Host existing-host") + + // Check that host config file was created + hostConfigPath := filepath.Join(tmpDir, ".databricks", "ssh-tunnel-configs", "new-host") + hostContent, err := os.ReadFile(hostConfigPath) + assert.NoError(t, err) + hostConfigStr := string(hostContent) + assert.Contains(t, hostConfigStr, "Host new-host") + assert.Contains(t, hostConfigStr, "--cluster=cluster-456") } diff --git a/experimental/ssh/internal/sshconfig/sshconfig.go b/experimental/ssh/internal/sshconfig/sshconfig.go new file mode 100644 index 0000000000..3a6713acbf --- /dev/null +++ b/experimental/ssh/internal/sshconfig/sshconfig.go @@ -0,0 +1,172 @@ +package sshconfig + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/databricks/cli/libs/cmdio" +) + +const ( + // configDirName is the directory name for Databricks SSH tunnel configs, relative to the user's home directory. + configDirName = ".databricks/ssh-tunnel-configs" +) + +func GetConfigDir() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get home directory: %w", err) + } + return filepath.Join(homeDir, configDirName), nil +} + +func GetMainConfigPath() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get home directory: %w", err) + } + return filepath.Join(homeDir, ".ssh", "config"), nil +} + +func GetMainConfigPathOrDefault(configPath string) (string, error) { + if configPath != "" { + return configPath, nil + } + return GetMainConfigPath() +} + +func EnsureMainConfigExists(configPath string) error { + _, err := os.Stat(configPath) + if os.IsNotExist(err) { + sshDir := filepath.Dir(configPath) + err = os.MkdirAll(sshDir, 0o700) + if err != nil { + return fmt.Errorf("failed to create SSH directory: %w", err) + } + err = os.WriteFile(configPath, []byte(""), 0o600) + if err != nil { + return fmt.Errorf("failed to create SSH config file: %w", err) + } + return nil + } + return err +} + +func EnsureIncludeDirective(configPath string) error { + configDir, err := GetConfigDir() + if err != nil { + return err + } + + err = os.MkdirAll(configDir, 0o700) + if err != nil { + return fmt.Errorf("failed to create Databricks SSH config directory: %w", err) + } + + err = EnsureMainConfigExists(configPath) + if err != nil { + return err + } + + content, err := os.ReadFile(configPath) + if err != nil { + return fmt.Errorf("failed to read SSH config file: %w", err) + } + + // Convert path to forward slashes for SSH config compatibility across platforms + configDirUnix := filepath.ToSlash(configDir) + + includeLine := fmt.Sprintf("Include %s/*", configDirUnix) + if strings.Contains(string(content), includeLine) { + return nil + } + + newContent := includeLine + "\n" + if len(content) > 0 && !strings.HasPrefix(string(content), "\n") { + newContent += "\n" + } + newContent += string(content) + + err = os.WriteFile(configPath, []byte(newContent), 0o600) + if err != nil { + return fmt.Errorf("failed to update SSH config file with Include directive: %w", err) + } + + return nil +} + +func GetHostConfigPath(hostName string) (string, error) { + configDir, err := GetConfigDir() + if err != nil { + return "", err + } + return filepath.Join(configDir, hostName), nil +} + +func HostConfigExists(hostName string) (bool, error) { + configPath, err := GetHostConfigPath(hostName) + if err != nil { + return false, err + } + _, err = os.Stat(configPath) + if os.IsNotExist(err) { + return false, nil + } + if err != nil { + return false, fmt.Errorf("failed to check host config file: %w", err) + } + return true, nil +} + +// Returns true if the config was created/updated, false if it was skipped. +func CreateOrUpdateHostConfig(ctx context.Context, hostName, hostConfig string, recreate bool) (bool, error) { + configPath, err := GetHostConfigPath(hostName) + if err != nil { + return false, err + } + + exists, err := HostConfigExists(hostName) + if err != nil { + return false, err + } + + if exists && !recreate { + return false, nil + } + + configDir := filepath.Dir(configPath) + err = os.MkdirAll(configDir, 0o700) + if err != nil { + return false, fmt.Errorf("failed to create config directory: %w", err) + } + + err = os.WriteFile(configPath, []byte(hostConfig), 0o600) + if err != nil { + return false, fmt.Errorf("failed to write host config file: %w", err) + } + + return true, nil +} + +func PromptRecreateConfig(ctx context.Context, hostName string) (bool, error) { + response, err := cmdio.AskYesOrNo(ctx, fmt.Sprintf("Host '%s' already exists. Do you want to recreate the config?", hostName)) + if err != nil { + return false, err + } + return response, nil +} + +func GenerateHostConfig(hostName, userName, identityFile, proxyCommand string) string { + return fmt.Sprintf(` +Host %s + User %s + ConnectTimeout 360 + StrictHostKeyChecking accept-new + IdentitiesOnly yes + IdentityFile %q + ProxyCommand %s +`, hostName, userName, identityFile, proxyCommand) +} diff --git a/experimental/ssh/internal/sshconfig/sshconfig_test.go b/experimental/ssh/internal/sshconfig/sshconfig_test.go new file mode 100644 index 0000000000..5fa13923ee --- /dev/null +++ b/experimental/ssh/internal/sshconfig/sshconfig_test.go @@ -0,0 +1,223 @@ +package sshconfig + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/databricks/cli/libs/cmdio" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetConfigDir(t *testing.T) { + dir, err := GetConfigDir() + assert.NoError(t, err) + assert.Contains(t, dir, filepath.Join(".databricks", "ssh-tunnel-configs")) +} + +func TestGetMainConfigPath(t *testing.T) { + path, err := GetMainConfigPath() + assert.NoError(t, err) + assert.Contains(t, path, filepath.Join(".ssh", "config")) +} + +func TestGetMainConfigPathOrDefault(t *testing.T) { + path, err := GetMainConfigPathOrDefault("/custom/path") + assert.NoError(t, err) + assert.Equal(t, "/custom/path", path) + + path, err = GetMainConfigPathOrDefault("") + assert.NoError(t, err) + assert.Contains(t, path, filepath.Join(".ssh", "config")) +} + +func TestEnsureMainConfigExists(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".ssh", "config") + + err := EnsureMainConfigExists(configPath) + assert.NoError(t, err) + + _, err = os.Stat(filepath.Dir(configPath)) + assert.NoError(t, err) + + _, err = os.Stat(configPath) + assert.NoError(t, err) + + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Empty(t, content) +} + +func TestEnsureIncludeDirective_NewConfig(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".ssh", "config") + + // Set home directory for test + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + err := EnsureIncludeDirective(configPath) + assert.NoError(t, err) + + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + + configStr := string(content) + assert.Contains(t, configStr, "Include") + // SSH config uses forward slashes on all platforms + assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") +} + +func TestEnsureIncludeDirective_AlreadyExists(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + configPath := filepath.Join(tmpDir, ".ssh", "config") + + configDir, err := GetConfigDir() + require.NoError(t, err) + + // Use forward slashes as that's what SSH config uses + configDirUnix := filepath.ToSlash(configDir) + existingContent := "Include " + configDirUnix + "/*\n\nHost example\n User test\n" + err = os.MkdirAll(filepath.Dir(configPath), 0o700) + require.NoError(t, err) + err = os.WriteFile(configPath, []byte(existingContent), 0o600) + require.NoError(t, err) + + err = EnsureIncludeDirective(configPath) + assert.NoError(t, err) + + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Equal(t, existingContent, string(content)) +} + +func TestEnsureIncludeDirective_PrependsToExisting(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".ssh", "config") + + // Set home directory for test + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + existingContent := "Host example\n User test\n" + err := os.MkdirAll(filepath.Dir(configPath), 0o700) + require.NoError(t, err) + err = os.WriteFile(configPath, []byte(existingContent), 0o600) + require.NoError(t, err) + + err = EnsureIncludeDirective(configPath) + assert.NoError(t, err) + + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + + configStr := string(content) + assert.Contains(t, configStr, "Include") + // SSH config uses forward slashes on all platforms + assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") + assert.Contains(t, configStr, "Host example") + + includeIndex := len("Include") + hostIndex := len(configStr) - len(existingContent) + assert.Less(t, includeIndex, hostIndex, "Include directive should come before existing content") +} + +func TestGetHostConfigPath(t *testing.T) { + path, err := GetHostConfigPath("test-host") + assert.NoError(t, err) + assert.Contains(t, path, filepath.Join(".databricks", "ssh-tunnel-configs", "test-host")) +} + +func TestHostConfigExists(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + exists, err := HostConfigExists("nonexistent") + assert.NoError(t, err) + assert.False(t, exists) + + configDir := filepath.Join(tmpDir, configDirName) + err = os.MkdirAll(configDir, 0o700) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(configDir, "existing-host"), []byte("config"), 0o600) + require.NoError(t, err) + + exists, err = HostConfigExists("existing-host") + assert.NoError(t, err) + assert.True(t, exists) +} + +func TestCreateOrUpdateHostConfig_NewConfig(t *testing.T) { + ctx := cmdio.MockDiscard(context.Background()) + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + hostConfig := "Host test\n User root\n" + created, err := CreateOrUpdateHostConfig(ctx, "test-host", hostConfig, false) + assert.NoError(t, err) + assert.True(t, created) + + configPath, err := GetHostConfigPath("test-host") + require.NoError(t, err) + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Equal(t, hostConfig, string(content)) +} + +func TestCreateOrUpdateHostConfig_ExistingConfigNoRecreate(t *testing.T) { + ctx := cmdio.MockDiscard(context.Background()) + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + configDir := filepath.Join(tmpDir, configDirName) + err := os.MkdirAll(configDir, 0o700) + require.NoError(t, err) + existingConfig := "Host test\n User admin\n" + err = os.WriteFile(filepath.Join(configDir, "test-host"), []byte(existingConfig), 0o600) + require.NoError(t, err) + + newConfig := "Host test\n User root\n" + created, err := CreateOrUpdateHostConfig(ctx, "test-host", newConfig, false) + assert.NoError(t, err) + assert.False(t, created) + + configPath, err := GetHostConfigPath("test-host") + require.NoError(t, err) + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Equal(t, existingConfig, string(content)) +} + +func TestCreateOrUpdateHostConfig_ExistingConfigWithRecreate(t *testing.T) { + ctx := cmdio.MockDiscard(context.Background()) + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + configDir := filepath.Join(tmpDir, configDirName) + err := os.MkdirAll(configDir, 0o700) + require.NoError(t, err) + existingConfig := "Host test\n User admin\n" + err = os.WriteFile(filepath.Join(configDir, "test-host"), []byte(existingConfig), 0o600) + require.NoError(t, err) + + newConfig := "Host test\n User root\n" + created, err := CreateOrUpdateHostConfig(ctx, "test-host", newConfig, true) + assert.NoError(t, err) + assert.True(t, created) + + configPath, err := GetHostConfigPath("test-host") + require.NoError(t, err) + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Equal(t, newConfig, string(content)) +}