From d7899afde09b609e59647a0d4f89be3af3a1d20f Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Fri, 19 Dec 2025 10:42:30 +0100 Subject: [PATCH 01/25] WIP: Add serverless GPU compute support to SSH tunnel Jobs API is not yet ready --- experimental/ssh/cmd/connect.go | 21 +- experimental/ssh/cmd/server.go | 4 + experimental/ssh/internal/client/client.go | 201 +++++++++++++----- .../internal/client/ssh-server-bootstrap.py | 48 +++-- experimental/ssh/internal/keys/keys.go | 9 +- experimental/ssh/internal/keys/secrets.go | 12 +- experimental/ssh/internal/server/server.go | 16 +- experimental/ssh/internal/server/sshd.go | 2 +- experimental/ssh/internal/setup/setup.go | 24 ++- experimental/ssh/internal/setup/setup_test.go | 13 +- .../ssh/internal/workspace/workspace.go | 32 +-- 11 files changed, 270 insertions(+), 112 deletions(-) diff --git a/experimental/ssh/cmd/connect.go b/experimental/ssh/cmd/connect.go index 1dc9c22337..af130fe4e6 100644 --- a/experimental/ssh/cmd/connect.go +++ b/experimental/ssh/cmd/connect.go @@ -1,6 +1,7 @@ package ssh import ( + "errors" "time" "github.com/databricks/cli/cmd/root" @@ -18,10 +19,18 @@ func newConnectCommand() *cobra.Command { This command establishes an SSH connection to Databricks compute, setting up the SSH server and handling the connection proxy. +For dedicated clusters: + databricks ssh connect --cluster= + +For serverless compute: + databricks ssh connect --name= [--accelerator=] + ` + disclaimer, } var clusterID string + var connectionName string + // var accelerator string var proxyMode bool var serverMetadata string var shutdownDelay time.Duration @@ -31,8 +40,8 @@ the SSH server and handling the connection proxy. var autoStartCluster bool var userKnownHostsFile string - cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (required)") - cmd.MarkFlagRequired("cluster") + cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)") + cmd.Flags().StringVar(&connectionName, "name", "", "Connection name (for serverless compute)") cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects") cmd.Flags().IntVar(&maxClients, "max-clients", defaultMaxClients, "Maximum number of SSH clients") cmd.Flags().BoolVar(&autoStartCluster, "auto-start-cluster", true, "Automatically start the cluster if it is not running") @@ -64,9 +73,17 @@ the SSH server and handling the connection proxy. cmd.RunE = func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() wsClient := cmdctx.WorkspaceClient(ctx) + + if !proxyMode && clusterID == "" && connectionName == "" { + return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the serverless connection name") + } + + // TODO: validate connectionName if provided + opts := client.ClientOptions{ Profile: wsClient.Config.Profile, ClusterID: clusterID, + ConnectionName: connectionName, ProxyMode: proxyMode, ServerMetadata: serverMetadata, ShutdownDelay: shutdownDelay, diff --git a/experimental/ssh/cmd/server.go b/experimental/ssh/cmd/server.go index 77b8c1c156..efe283f28a 100644 --- a/experimental/ssh/cmd/server.go +++ b/experimental/ssh/cmd/server.go @@ -26,12 +26,15 @@ and proxies them to local SSH daemon processes. var maxClients int var shutdownDelay time.Duration var clusterID string + var sessionID string var version string var secretScopeName string var authorizedKeySecretName string cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID") cmd.MarkFlagRequired("cluster") + cmd.Flags().StringVar(&sessionID, "session-id", "", "Session identifier (cluster ID or serverless connection name)") + cmd.MarkFlagRequired("session-id") cmd.Flags().StringVar(&secretScopeName, "secret-scope-name", "", "Databricks secret scope name to store SSH keys") cmd.MarkFlagRequired("secret-scope-name") cmd.Flags().StringVar(&authorizedKeySecretName, "authorized-key-secret-name", "", "Name of the secret containing the client public key") @@ -56,6 +59,7 @@ and proxies them to local SSH daemon processes. wsc := cmdctx.WorkspaceClient(ctx) opts := server.ServerOptions{ ClusterID: clusterID, + SessionID: sessionID, MaxClients: maxClients, ShutdownDelay: shutdownDelay, Version: version, diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index a3c3d78889..0658b23956 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -36,8 +36,12 @@ var sshServerBootstrapScript string var errServerMetadata = errors.New("server metadata error") type ClientOptions struct { - // Id of the cluster to connect to + // Id of the cluster to connect to (for dedicated clusters) ClusterID string + // Connection name (for serverless compute). Used as unique identifier instead of ClusterID. + ConnectionName string + // GPU accelerator type (for serverless compute) + Accelerator string // Delay before shutting down the server after the last client disconnects ShutdownDelay time.Duration // Maximum number of SSH clients @@ -46,7 +50,7 @@ type ClientOptions struct { // to the cluster and proxy all traffic through stdin/stdout. // In the non proxy mode the CLI spawns an ssh client with the ProxyCommand config. ProxyMode bool - // Expected format: ",". + // Expected format: ",,". // If present, the CLI won't attempt to start the server. ServerMetadata string // How often the CLI should reconnect to the server with new auth. @@ -72,6 +76,19 @@ type ClientOptions struct { UserKnownHostsFile string } +func (o *ClientOptions) IsServerlessMode() bool { + return o.ClusterID == "" +} + +// SessionIdentifier returns the unique identifier for the session. +// For dedicated clusters, this is the cluster ID. For serverless, this is the connection name. +func (o *ClientOptions) SessionIdentifier() string { + if o.IsServerlessMode() { + return o.ConnectionName + } + return o.ClusterID +} + func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOptions) error { ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -84,22 +101,30 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt cancel() }() - err := checkClusterState(ctx, client, opts.ClusterID, opts.AutoStartCluster) - if err != nil { - return err + sessionID := opts.SessionIdentifier() + if sessionID == "" { + return errors.New("either --cluster or --name must be provided") } - secretScopeName, err := keys.CreateKeysSecretScope(ctx, client, opts.ClusterID) + // Only check cluster state for dedicated clusters + if !opts.IsServerlessMode() { + err := checkClusterState(ctx, client, opts.ClusterID, opts.AutoStartCluster) + if err != nil { + return err + } + } + + secretScopeName, err := keys.CreateKeysSecretScope(ctx, client, sessionID) if err != nil { return fmt.Errorf("failed to create secret scope: %w", err) } - privateKeyBytes, publicKeyBytes, err := keys.CheckAndGenerateSSHKeyPairFromSecrets(ctx, client, opts.ClusterID, secretScopeName, opts.ClientPrivateKeyName, opts.ClientPublicKeyName) + privateKeyBytes, publicKeyBytes, err := keys.CheckAndGenerateSSHKeyPairFromSecrets(ctx, client, secretScopeName, opts.ClientPrivateKeyName, opts.ClientPublicKeyName) if err != nil { return fmt.Errorf("failed to get or generate SSH key pair from secrets: %w", err) } - keyPath, err := keys.GetLocalSSHKeyPath(opts.ClusterID, opts.SSHKeysDir) + keyPath, err := keys.GetLocalSSHKeyPath(sessionID, opts.SSHKeysDir) if err != nil { return fmt.Errorf("failed to get local keys folder: %w", err) } @@ -113,6 +138,7 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt var userName string var serverPort int + var clusterID string version := build.GetInfo().Version @@ -121,14 +147,15 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt if err := UploadTunnelReleases(ctx, client, version, opts.ReleasesDir); err != nil { return fmt.Errorf("failed to upload ssh-tunnel binaries: %w", err) } - userName, serverPort, err = ensureSSHServerIsRunning(ctx, client, version, secretScopeName, opts) + userName, serverPort, clusterID, err = ensureSSHServerIsRunning(ctx, client, version, secretScopeName, opts) if err != nil { return fmt.Errorf("failed to ensure that ssh server is running: %w", err) } } else { + // Metadata format: ",," metadata := strings.Split(opts.ServerMetadata, ",") - if len(metadata) != 2 { - return fmt.Errorf("invalid metadata: %s, expected format: ,", opts.ServerMetadata) + if len(metadata) < 2 { + return fmt.Errorf("invalid metadata: %s, expected format: ,[,]", opts.ServerMetadata) } userName = metadata[0] if userName == "" { @@ -138,55 +165,88 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt if err != nil { return fmt.Errorf("cannot parse port from metadata: %s, %w", opts.ServerMetadata, err) } + if len(metadata) >= 3 { + clusterID = metadata[2] + } else { + clusterID = opts.ClusterID + } + } + + // For serverless mode, we need the cluster ID from metadata for Driver Proxy connections + if opts.IsServerlessMode() && clusterID == "" { + return errors.New("cluster ID is required for serverless connections but was not found in metadata") } cmdio.LogString(ctx, "Remote user name: "+userName) cmdio.LogString(ctx, fmt.Sprintf("Server port: %d", serverPort)) + if opts.IsServerlessMode() { + cmdio.LogString(ctx, "Cluster ID (from serverless job): "+clusterID) + } if opts.ProxyMode { - return runSSHProxy(ctx, client, serverPort, opts) + return runSSHProxy(ctx, client, serverPort, clusterID, opts) } else { cmdio.LogString(ctx, fmt.Sprintf("Additional SSH arguments: %v", opts.AdditionalArgs)) - return spawnSSHClient(ctx, userName, keyPath, serverPort, opts) + return spawnSSHClient(ctx, userName, keyPath, serverPort, clusterID, opts) } } -func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, clusterID, version string) (int, string, error) { - serverPort, err := sshWorkspace.GetWorkspaceMetadata(ctx, client, version, clusterID) +// getServerMetadata retrieves the server metadata from the workspace and validates it via Driver Proxy. +// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). +// For dedicated clusters, clusterID should be the same as sessionID. +// For serverless, clusterID is read from the workspace metadata. +func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, sessionID, clusterID, version string) (int, string, string, error) { + wsMetadata, err := sshWorkspace.GetWorkspaceMetadata(ctx, client, version, sessionID) if err != nil { - return 0, "", errors.Join(errServerMetadata, err) + return 0, "", "", errors.Join(errServerMetadata, err) + } + cmdio.LogString(ctx, "Workspace metadata: "+fmt.Sprintf("%+v", wsMetadata)) + + // For serverless mode, the cluster ID comes from the metadata + effectiveClusterID := clusterID + if wsMetadata.ClusterID != "" { + effectiveClusterID = wsMetadata.ClusterID } + + if effectiveClusterID == "" { + return 0, "", "", errors.Join(errServerMetadata, errors.New("cluster ID not available in metadata")) + } + workspaceID, err := client.CurrentWorkspaceID(ctx) if err != nil { - return 0, "", err + return 0, "", "", err } - metadataURL := fmt.Sprintf("%s/driver-proxy-api/o/%d/%s/%d/metadata", client.Config.Host, workspaceID, clusterID, serverPort) + metadataURL := fmt.Sprintf("%s/driver-proxy-api/o/%d/%s/%d/metadata", client.Config.Host, workspaceID, effectiveClusterID, wsMetadata.Port) + cmdio.LogString(ctx, "Metadata URL: "+metadataURL) req, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil) if err != nil { - return 0, "", err + return 0, "", "", err } if err := client.Config.Authenticate(req); err != nil { - return 0, "", err + return 0, "", "", err } resp, err := http.DefaultClient.Do(req) if err != nil { - return 0, "", err + return 0, "", "", err } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return 0, "", errors.Join(errServerMetadata, fmt.Errorf("server is not ok, status code %d", resp.StatusCode)) - } - bodyBytes, err := io.ReadAll(resp.Body) if err != nil { - return 0, "", err + return 0, "", "", err } - return serverPort, string(bodyBytes), nil + cmdio.LogString(ctx, "Metadata response: "+string(bodyBytes)) + + if resp.StatusCode != http.StatusOK { + return 0, "", "", errors.Join(errServerMetadata, fmt.Errorf("server is not ok, status code %d", resp.StatusCode)) + } + + return wsMetadata.Port, string(bodyBytes), effectiveClusterID, nil } func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) (int64, error) { - contentDir, err := sshWorkspace.GetWorkspaceContentDir(ctx, client, version, opts.ClusterID) + sessionID := opts.SessionIdentifier() + contentDir, err := sshWorkspace.GetWorkspaceContentDir(ctx, client, version, sessionID) if err != nil { return 0, fmt.Errorf("failed to get workspace content directory: %w", err) } @@ -196,7 +256,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, return 0, fmt.Errorf("failed to create directory in the remote workspace: %w", err) } - sshTunnelJobName := "ssh-server-bootstrap-" + opts.ClusterID + sshTunnelJobName := "ssh-server-bootstrap-" + sessionID jobNotebookPath := filepath.ToSlash(filepath.Join(contentDir, "ssh-server-bootstrap")) notebookContent := "# Databricks notebook source\n" + sshServerBootstrapScript encodedContent := base64.StdEncoding.EncodeToString([]byte(notebookContent)) @@ -212,26 +272,45 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, return 0, fmt.Errorf("failed to create ssh-tunnel notebook: %w", err) } + baseParams := map[string]string{ + "version": version, + "secretScopeName": secretScopeName, + "authorizedKeySecretName": opts.ClientPublicKeyName, + "shutdownDelay": opts.ShutdownDelay.String(), + "maxClients": strconv.Itoa(opts.MaxClients), + "sessionId": sessionID, + } + + task := jobs.SubmitTask{ + TaskKey: "start_ssh_server", + NotebookTask: &jobs.NotebookTask{ + NotebookPath: jobNotebookPath, + BaseParameters: baseParams, + }, + TimeoutSeconds: int(opts.ServerTimeout.Seconds()), + } + + if opts.IsServerlessMode() { + task.EnvironmentKey = "ssh-tunnel-serverless" + // TODO: Add GPU accelerator configuration when Jobs API supports it + } else { + task.ExistingClusterId = opts.ClusterID + } + submitRun := jobs.SubmitRun{ RunName: sshTunnelJobName, TimeoutSeconds: int(opts.ServerTimeout.Seconds()), - Tasks: []jobs.SubmitTask{ - { - TaskKey: "start_ssh_server", - NotebookTask: &jobs.NotebookTask{ - NotebookPath: jobNotebookPath, - BaseParameters: map[string]string{ - "version": version, - "secretScopeName": secretScopeName, - "authorizedKeySecretName": opts.ClientPublicKeyName, - "shutdownDelay": opts.ShutdownDelay.String(), - "maxClients": strconv.Itoa(opts.MaxClients), - }, - }, - TimeoutSeconds: int(opts.ServerTimeout.Seconds()), - ExistingClusterId: opts.ClusterID, + Tasks: []jobs.SubmitTask{task}, + } + + if opts.IsServerlessMode() { + env := jobs.JobEnvironment{ + EnvironmentKey: "ssh-tunnel-serverless", + Spec: &compute.Environment{ + EnvironmentVersion: "3", }, - }, + } + submitRun.Environments = []jobs.JobEnvironment{env} } cmdio.LogString(ctx, "Submitting a job to start the ssh server...") @@ -243,12 +322,14 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, return runResult.Response.RunId, nil } -func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, serverPort int, opts ClientOptions) error { - proxyCommand, err := setup.GenerateProxyCommand(opts.ClusterID, opts.AutoStartCluster, opts.ShutdownDelay, opts.Profile, userName, serverPort, opts.HandoverTimeout) +func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, serverPort int, clusterID string, opts ClientOptions) error { + proxyCommand, err := setup.GenerateProxyCommand(opts.SessionIdentifier(), clusterID, opts.IsServerlessMode(), opts.AutoStartCluster, opts.ShutdownDelay, opts.Profile, userName, serverPort, opts.HandoverTimeout) if err != nil { return fmt.Errorf("failed to generate ProxyCommand: %w", err) } + hostName := opts.SessionIdentifier() + sshArgs := []string{ "-l", userName, "-i", privateKeyPath, @@ -260,7 +341,7 @@ func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, server if opts.UserKnownHostsFile != "" { sshArgs = append(sshArgs, "-o", "UserKnownHostsFile="+opts.UserKnownHostsFile) } - sshArgs = append(sshArgs, opts.ClusterID) + sshArgs = append(sshArgs, hostName) sshArgs = append(sshArgs, opts.AdditionalArgs...) cmdio.LogString(ctx, "Launching SSH client: ssh "+strings.Join(sshArgs, " ")) @@ -274,9 +355,9 @@ func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, server return sshCmd.Run() } -func runSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, serverPort int, opts ClientOptions) error { +func runSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, serverPort int, clusterID string, opts ClientOptions) error { createConn := func(ctx context.Context, connID string) (*websocket.Conn, error) { - return createWebsocketConnection(ctx, client, connID, opts.ClusterID, serverPort) + return createWebsocketConnection(ctx, client, connID, clusterID, serverPort) } requestHandoverTick := func() <-chan time.Time { return time.After(opts.HandoverTimeout) @@ -304,14 +385,18 @@ func checkClusterState(ctx context.Context, client *databricks.WorkspaceClient, return nil } -func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) (string, int, error) { - serverPort, userName, err := getServerMetadata(ctx, client, opts.ClusterID, version) +func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) (string, int, string, error) { + sessionID := opts.SessionIdentifier() + // For dedicated clusters, use clusterID; for serverless, it will be read from metadata + clusterID := opts.ClusterID + + serverPort, userName, effectiveClusterID, err := getServerMetadata(ctx, client, sessionID, clusterID, version) if errors.Is(err, errServerMetadata) { cmdio.LogString(ctx, "SSH server is not running, starting it now...") runID, err := submitSSHTunnelJob(ctx, client, version, secretScopeName, opts) if err != nil { - return "", 0, fmt.Errorf("failed to submit ssh server job: %w", err) + return "", 0, "", fmt.Errorf("failed to submit ssh server job: %w", err) } cmdio.LogString(ctx, fmt.Sprintf("Job submitted successfully with run ID: %d", runID)) @@ -319,21 +404,21 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC maxRetries := 30 for retries := range maxRetries { if ctx.Err() != nil { - return "", 0, ctx.Err() + return "", 0, "", ctx.Err() } - serverPort, userName, err = getServerMetadata(ctx, client, opts.ClusterID, version) + serverPort, userName, effectiveClusterID, err = getServerMetadata(ctx, client, sessionID, clusterID, version) if err == nil { cmdio.LogString(ctx, "Health check successful, starting ssh WebSocket connection...") break } else if retries < maxRetries-1 { time.Sleep(2 * time.Second) } else { - return "", 0, fmt.Errorf("failed to start the ssh server: %w", err) + return "", 0, "", fmt.Errorf("failed to start the ssh server: %w", err) } } } else if err != nil { - return "", 0, err + return "", 0, "", err } - return userName, serverPort, nil + return userName, serverPort, effectiveClusterID, nil } diff --git a/experimental/ssh/internal/client/ssh-server-bootstrap.py b/experimental/ssh/internal/client/ssh-server-bootstrap.py index 8b8170bf42..8dc0aff16a 100644 --- a/experimental/ssh/internal/client/ssh-server-bootstrap.py +++ b/experimental/ssh/internal/client/ssh-server-bootstrap.py @@ -17,6 +17,7 @@ dbutils.widgets.text("authorizedKeySecretName", "") dbutils.widgets.text("maxClients", "10") dbutils.widgets.text("shutdownDelay", "10m") +dbutils.widgets.text("sessionId", "") # Required: unique identifier for the session def cleanup(): @@ -111,6 +112,9 @@ def run_ssh_server(): shutdown_delay = dbutils.widgets.get("shutdownDelay") max_clients = dbutils.widgets.get("maxClients") + session_id = dbutils.widgets.get("sessionId") + if not session_id: + raise RuntimeError("Session ID is required. Please provide it using the 'sessionId' widget.") arch = platform.machine() if arch == "x86_64": @@ -127,29 +131,29 @@ def run_ssh_server(): binary_path = f"/Workspace/Users/{user_name}/.databricks/ssh-tunnel/{version}/{cli_name}/databricks" + server_args = [ + binary_path, + "ssh", + "server", + f"--cluster={ctx.clusterId}", + f"--session-id={session_id}", + f"--secret-scope-name={secrets_scope}", + f"--authorized-key-secret-name={public_key_secret_name}", + f"--max-clients={max_clients}", + f"--shutdown-delay={shutdown_delay}", + f"--version={version}", + # "info" has enough verbosity for debugging purposes, and "debug" log level prints too much (including secrets) + "--log-level=info", + "--log-format=json", + # To get the server logs: + # 1. Get a job run id from the "databricks ssh connect" output + # 2. Run "databricks jobs get-run " and open a run_page_url + # TODO: file with log rotation + "--log-file=stdout", + ] + try: - subprocess.run( - [ - binary_path, - "ssh", - "server", - f"--cluster={ctx.clusterId}", - f"--secret-scope-name={secrets_scope}", - f"--authorized-key-secret-name={public_key_secret_name}", - f"--max-clients={max_clients}", - f"--shutdown-delay={shutdown_delay}", - f"--version={version}", - # "info" has enough verbosity for debugging purposes, and "debug" log level prints too much (including secrets) - "--log-level=info", - "--log-format=json", - # To get the server logs: - # 1. Get a job run id from the "databricks ssh connect" output - # 2. Run "databricks jobs get-run " and open a run_page_url - # TODO: file with log rotation - "--log-file=stdout", - ], - check=True, - ) + subprocess.run(server_args, check=True) finally: kill_all_children() diff --git a/experimental/ssh/internal/keys/keys.go b/experimental/ssh/internal/keys/keys.go index a1c279c749..735f4f0f1a 100644 --- a/experimental/ssh/internal/keys/keys.go +++ b/experimental/ssh/internal/keys/keys.go @@ -14,8 +14,9 @@ import ( "golang.org/x/crypto/ssh" ) -// We use different client keys for each cluster as a good practice for better isolation and control. -func GetLocalSSHKeyPath(clusterID, keysDir string) (string, error) { +// We use different client keys for each session as a good practice for better isolation and control. +// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). +func GetLocalSSHKeyPath(sessionID, keysDir string) (string, error) { if keysDir == "" { homeDir, err := os.UserHomeDir() if err != nil { @@ -23,7 +24,7 @@ func GetLocalSSHKeyPath(clusterID, keysDir string) (string, error) { } keysDir = filepath.Join(homeDir, ".databricks", "ssh-tunnel-keys") } - return filepath.Join(keysDir, clusterID), nil + return filepath.Join(keysDir, sessionID), nil } func generateSSHKeyPair() ([]byte, []byte, error) { @@ -68,7 +69,7 @@ func SaveSSHKeyPair(keyPath string, privateKeyBytes, publicKeyBytes []byte) erro return nil } -func CheckAndGenerateSSHKeyPairFromSecrets(ctx context.Context, client *databricks.WorkspaceClient, clusterID, secretScopeName, privateKeyName, publicKeyName string) ([]byte, []byte, error) { +func CheckAndGenerateSSHKeyPairFromSecrets(ctx context.Context, client *databricks.WorkspaceClient, secretScopeName, privateKeyName, publicKeyName string) ([]byte, []byte, error) { privateKeyBytes, err := GetSecret(ctx, client, secretScopeName, privateKeyName) if err != nil { privateKeyBytes, publicKeyBytes, err := generateSSHKeyPair() diff --git a/experimental/ssh/internal/keys/secrets.go b/experimental/ssh/internal/keys/secrets.go index 0a7b2c1266..eac692f235 100644 --- a/experimental/ssh/internal/keys/secrets.go +++ b/experimental/ssh/internal/keys/secrets.go @@ -10,12 +10,14 @@ import ( "github.com/databricks/databricks-sdk-go/service/workspace" ) -func CreateKeysSecretScope(ctx context.Context, client *databricks.WorkspaceClient, clusterID string) (string, error) { +// CreateKeysSecretScope creates or retrieves the secret scope for SSH keys. +// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). +func CreateKeysSecretScope(ctx context.Context, client *databricks.WorkspaceClient, sessionID string) (string, error) { me, err := client.CurrentUser.Me(ctx) if err != nil { return "", fmt.Errorf("failed to get current user: %w", err) } - secretScopeName := fmt.Sprintf("%s-%s-ssh-tunnel-keys", me.UserName, clusterID) + secretScopeName := fmt.Sprintf("%s-%s-ssh-tunnel-keys", me.UserName, sessionID) err = client.Secrets.CreateScope(ctx, workspace.CreateScope{ Scope: secretScopeName, }) @@ -53,8 +55,10 @@ func putSecret(ctx context.Context, client *databricks.WorkspaceClient, scope, k return nil } -func PutSecretInScope(ctx context.Context, client *databricks.WorkspaceClient, clusterID, key, value string) (string, error) { - scopeName, err := CreateKeysSecretScope(ctx, client, clusterID) +// PutSecretInScope creates the secret scope if needed and stores the secret. +// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). +func PutSecretInScope(ctx context.Context, client *databricks.WorkspaceClient, sessionID, key, value string) (string, error) { + scopeName, err := CreateKeysSecretScope(ctx, client, sessionID) if err != nil { return "", err } diff --git a/experimental/ssh/internal/server/server.go b/experimental/ssh/internal/server/server.go index 66837cbc72..92fa76050a 100644 --- a/experimental/ssh/internal/server/server.go +++ b/experimental/ssh/internal/server/server.go @@ -29,8 +29,11 @@ type ServerOptions struct { MaxClients int // Delay before shutting down the server when there are no active connections ShutdownDelay time.Duration - // The cluster ID that the client started this server on + // The cluster ID that the client started this server on (required for Driver Proxy connections) ClusterID string + // SessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). + // Used for metadata storage path. Defaults to ClusterID if not set. + SessionID string // The directory to store sshd configuration ConfigDir string // The name of the secrets scope to use for client and server keys @@ -56,7 +59,12 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ServerOpt listenAddr := fmt.Sprintf("0.0.0.0:%d", port) log.Info(ctx, "Starting server on "+listenAddr) - err = workspace.SaveWorkspaceMetadata(ctx, client, opts.Version, opts.ClusterID, port) + // Save metadata including ClusterID (required for Driver Proxy connections in serverless mode) + metadata := &workspace.WorkspaceMetadata{ + Port: port, + ClusterID: opts.ClusterID, + } + err = workspace.SaveWorkspaceMetadata(ctx, client, opts.Version, opts.SessionID, metadata) if err != nil { return fmt.Errorf("failed to save metadata to the workspace: %w", err) } @@ -77,6 +85,10 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ServerOpt connections := proxy.NewConnectionsManager(opts.MaxClients, opts.ShutdownDelay) http.Handle("/ssh", proxy.NewProxyServer(ctx, connections, createServerCommand)) http.HandleFunc("/metadata", serveMetadata) + + http.Handle("/driver-proxy-http/ssh", proxy.NewProxyServer(ctx, connections, createServerCommand)) + http.HandleFunc("/driver-proxy-http/metadata", serveMetadata) + go handleTimeout(ctx, connections.TimedOut, opts.ShutdownDelay) return http.ListenAndServe(listenAddr, nil) diff --git a/experimental/ssh/internal/server/sshd.go b/experimental/ssh/internal/server/sshd.go index 2f45588a33..7a038e73b5 100644 --- a/experimental/ssh/internal/server/sshd.go +++ b/experimental/ssh/internal/server/sshd.go @@ -36,7 +36,7 @@ func prepareSSHDConfig(ctx context.Context, client *databricks.WorkspaceClient, return "", fmt.Errorf("failed to create SSH directory: %w", err) } - privateKeyBytes, publicKeyBytes, err := keys.CheckAndGenerateSSHKeyPairFromSecrets(ctx, client, opts.ClusterID, opts.SecretScopeName, opts.ServerPrivateKeyName, opts.ServerPublicKeyName) + privateKeyBytes, publicKeyBytes, err := keys.CheckAndGenerateSSHKeyPairFromSecrets(ctx, client, opts.SecretScopeName, opts.ServerPrivateKeyName, opts.ServerPublicKeyName) if err != nil { return "", fmt.Errorf("failed to get SSH key pair from secrets: %w", err) } diff --git a/experimental/ssh/internal/setup/setup.go b/experimental/ssh/internal/setup/setup.go index 1c60a1e4f7..726747e4fd 100644 --- a/experimental/ssh/internal/setup/setup.go +++ b/experimental/ssh/internal/setup/setup.go @@ -56,17 +56,31 @@ func resolveConfigPath(configPath string) (string, error) { return filepath.Join(homeDir, ".ssh", "config"), nil } -func GenerateProxyCommand(clusterId string, autoStartCluster bool, shutdownDelay time.Duration, profile, userName string, serverPort int, handoverTimeout time.Duration) (string, error) { +// GenerateProxyCommand generates the ProxyCommand string for SSH config. +// sessionID is the unique identifier (cluster ID for dedicated clusters, connection name for serverless). +// clusterID is the actual cluster ID for Driver Proxy connections (same as sessionID for dedicated clusters, +// but obtained from job metadata for serverless). +func GenerateProxyCommand(sessionID, clusterID string, serverlessMode, autoStartCluster bool, shutdownDelay time.Duration, profile, userName string, serverPort int, handoverTimeout time.Duration) (string, error) { executablePath, err := os.Executable() if err != nil { return "", fmt.Errorf("failed to get current executable path: %w", err) } - proxyCommand := fmt.Sprintf("%q ssh connect --proxy --cluster=%s --auto-start-cluster=%t --shutdown-delay=%s", - executablePath, clusterId, autoStartCluster, shutdownDelay.String()) + var proxyCommand string + if serverlessMode { + proxyCommand = fmt.Sprintf("%q ssh connect --proxy --name=%s --shutdown-delay=%s", + executablePath, sessionID, shutdownDelay.String()) + } else { + proxyCommand = fmt.Sprintf("%q ssh connect --proxy --cluster=%s --auto-start-cluster=%t --shutdown-delay=%s", + executablePath, clusterID, autoStartCluster, shutdownDelay.String()) + } if userName != "" && serverPort != 0 { - proxyCommand += " --metadata=" + userName + "," + strconv.Itoa(serverPort) + if serverlessMode && clusterID != "" { + proxyCommand += " --metadata=" + userName + "," + strconv.Itoa(serverPort) + "," + clusterID + } else { + proxyCommand += " --metadata=" + userName + "," + strconv.Itoa(serverPort) + } } if handoverTimeout > 0 { @@ -86,7 +100,7 @@ func generateHostConfig(opts SetupOptions) (string, error) { return "", fmt.Errorf("failed to get local keys folder: %w", err) } - proxyCommand, err := GenerateProxyCommand(opts.ClusterID, opts.AutoStartCluster, opts.ShutdownDelay, opts.Profile, "", 0, 0) + proxyCommand, err := GenerateProxyCommand(opts.ClusterID, opts.ClusterID, false, opts.AutoStartCluster, opts.ShutdownDelay, opts.Profile, "", 0, 0) if err != nil { return "", fmt.Errorf("failed to generate ProxyCommand: %w", err) } diff --git a/experimental/ssh/internal/setup/setup_test.go b/experimental/ssh/internal/setup/setup_test.go index 7c4cb20925..27a0ced5bc 100644 --- a/experimental/ssh/internal/setup/setup_test.go +++ b/experimental/ssh/internal/setup/setup_test.go @@ -56,7 +56,7 @@ func TestValidateClusterAccess_ClusterNotFound(t *testing.T) { } func TestGenerateProxyCommand(t *testing.T) { - cmd, err := GenerateProxyCommand("cluster-123", true, 45*time.Second, "", "", 0, 0) + cmd, err := GenerateProxyCommand("cluster-123", "cluster-123", false, true, 45*time.Second, "", "", 0, 0) assert.NoError(t, err) assert.Contains(t, cmd, "ssh connect --proxy --cluster=cluster-123 --auto-start-cluster=true --shutdown-delay=45s") assert.NotContains(t, cmd, "--metadata") @@ -65,7 +65,7 @@ func TestGenerateProxyCommand(t *testing.T) { } func TestGenerateProxyCommand_WithExtraArgs(t *testing.T) { - cmd, err := GenerateProxyCommand("cluster-123", true, 45*time.Second, "test-profile", "user", 2222, 2*time.Minute) + cmd, err := GenerateProxyCommand("cluster-123", "cluster-123", false, true, 45*time.Second, "test-profile", "user", 2222, 2*time.Minute) assert.NoError(t, err) assert.Contains(t, cmd, "ssh connect --proxy --cluster=cluster-123 --auto-start-cluster=true --shutdown-delay=45s") assert.Contains(t, cmd, " --metadata=user,2222") @@ -73,6 +73,15 @@ func TestGenerateProxyCommand_WithExtraArgs(t *testing.T) { assert.Contains(t, cmd, " --profile=test-profile") } +func TestGenerateProxyCommand_ServerlessMode(t *testing.T) { + cmd, err := GenerateProxyCommand("my-connection", "serverless-cluster-id", true, false, 45*time.Second, "", "user", 2222, 0) + assert.NoError(t, err) + assert.Contains(t, cmd, "ssh connect --proxy --name=my-connection --shutdown-delay=45s") + assert.Contains(t, cmd, " --metadata=user,2222,serverless-cluster-id") + assert.NotContains(t, cmd, "--cluster=") + assert.NotContains(t, cmd, "--auto-start-cluster") +} + func TestGenerateHostConfig_Valid(t *testing.T) { // Create a temporary directory for testing tmpDir := t.TempDir() diff --git a/experimental/ssh/internal/workspace/workspace.go b/experimental/ssh/internal/workspace/workspace.go index 10e593951a..2f017cbae1 100644 --- a/experimental/ssh/internal/workspace/workspace.go +++ b/experimental/ssh/internal/workspace/workspace.go @@ -16,6 +16,8 @@ const metadataFileName = "metadata.json" type WorkspaceMetadata struct { Port int `json:"port"` + // ClusterID is required for Driver Proxy websocket connections (for any compute type, including serverless) + ClusterID string `json:"cluster_id,omitempty"` } func getWorkspaceRootDir(ctx context.Context, client *databricks.WorkspaceClient) (string, error) { @@ -34,49 +36,55 @@ func GetWorkspaceVersionedDir(ctx context.Context, client *databricks.WorkspaceC return filepath.ToSlash(filepath.Join(contentDir, version)), nil } -func GetWorkspaceContentDir(ctx context.Context, client *databricks.WorkspaceClient, version, clusterID string) (string, error) { +// GetWorkspaceContentDir returns the directory for storing session content. +// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). +func GetWorkspaceContentDir(ctx context.Context, client *databricks.WorkspaceClient, version, sessionID string) (string, error) { contentDir, err := GetWorkspaceVersionedDir(ctx, client, version) if err != nil { return "", fmt.Errorf("failed to get versioned workspace directory: %w", err) } - return filepath.ToSlash(filepath.Join(contentDir, clusterID)), nil + return filepath.ToSlash(filepath.Join(contentDir, sessionID)), nil } -func GetWorkspaceMetadata(ctx context.Context, client *databricks.WorkspaceClient, version, clusterID string) (int, error) { - contentDir, err := GetWorkspaceContentDir(ctx, client, version, clusterID) +// GetWorkspaceMetadata loads session metadata from the workspace. +// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). +func GetWorkspaceMetadata(ctx context.Context, client *databricks.WorkspaceClient, version, sessionID string) (*WorkspaceMetadata, error) { + contentDir, err := GetWorkspaceContentDir(ctx, client, version, sessionID) if err != nil { - return 0, fmt.Errorf("failed to get workspace content directory: %w", err) + return nil, fmt.Errorf("failed to get workspace content directory: %w", err) } metadataPath := filepath.ToSlash(filepath.Join(contentDir, metadataFileName)) content, err := client.Workspace.Download(ctx, metadataPath) if err != nil { - return 0, fmt.Errorf("failed to download metadata file: %w", err) + return nil, fmt.Errorf("failed to download metadata file: %w", err) } defer content.Close() metadataBytes, err := io.ReadAll(content) if err != nil { - return 0, fmt.Errorf("failed to read metadata content: %w", err) + return nil, fmt.Errorf("failed to read metadata content: %w", err) } var metadata WorkspaceMetadata err = json.Unmarshal(metadataBytes, &metadata) if err != nil { - return 0, fmt.Errorf("failed to parse metadata JSON: %w", err) + return nil, fmt.Errorf("failed to parse metadata JSON: %w", err) } - return metadata.Port, nil + return &metadata, nil } -func SaveWorkspaceMetadata(ctx context.Context, client *databricks.WorkspaceClient, version, clusterID string, port int) error { - metadataBytes, err := json.Marshal(WorkspaceMetadata{Port: port}) +// SaveWorkspaceMetadata saves session metadata to the workspace. +// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). +func SaveWorkspaceMetadata(ctx context.Context, client *databricks.WorkspaceClient, version, sessionID string, metadata *WorkspaceMetadata) error { + metadataBytes, err := json.Marshal(metadata) if err != nil { return fmt.Errorf("failed to marshal metadata: %w", err) } - contentDir, err := GetWorkspaceContentDir(ctx, client, version, clusterID) + contentDir, err := GetWorkspaceContentDir(ctx, client, version, sessionID) if err != nil { return fmt.Errorf("failed to get workspace content directory: %w", err) } From 140560ea609d1839a411e2919fa878923f55589a Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Tue, 23 Dec 2025 12:11:25 +0100 Subject: [PATCH 02/25] Add liteswap header value for traffic routing (dev/test only). --- experimental/ssh/cmd/connect.go | 5 +++++ experimental/ssh/internal/client/client.go | 16 +++++++++++----- experimental/ssh/internal/client/websockets.go | 5 ++++- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/experimental/ssh/cmd/connect.go b/experimental/ssh/cmd/connect.go index af130fe4e6..bc2f620e28 100644 --- a/experimental/ssh/cmd/connect.go +++ b/experimental/ssh/cmd/connect.go @@ -39,6 +39,7 @@ For serverless compute: var releasesDir string var autoStartCluster bool var userKnownHostsFile string + var liteswap string cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)") cmd.Flags().StringVar(&connectionName, "name", "", "Connection name (for serverless compute)") @@ -59,6 +60,9 @@ For serverless compute: cmd.Flags().StringVar(&userKnownHostsFile, "user-known-hosts-file", "", "Path to user known hosts file for SSH client") cmd.Flags().MarkHidden("user-known-hosts-file") + cmd.Flags().StringVar(&liteswap, "liteswap", "", "Liteswap header value for traffic routing (dev/test only)") + cmd.Flags().MarkHidden("liteswap") + cmd.PreRunE = func(cmd *cobra.Command, args []string) error { // CLI in the proxy mode is executed by the ssh client and can't prompt for input if proxyMode { @@ -95,6 +99,7 @@ For serverless compute: ClientPublicKeyName: clientPublicKeyName, ClientPrivateKeyName: clientPrivateKeyName, UserKnownHostsFile: userKnownHostsFile, + Liteswap: liteswap, AdditionalArgs: args, } return client.Run(ctx, wsClient, opts) diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 0658b23956..7c7869559c 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -74,6 +74,8 @@ type ClientOptions struct { AdditionalArgs []string // Optional path to the user known hosts file. UserKnownHostsFile string + // Liteswap header value for traffic routing (dev/test only). + Liteswap string } func (o *ClientOptions) IsServerlessMode() bool { @@ -107,7 +109,8 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt } // Only check cluster state for dedicated clusters - if !opts.IsServerlessMode() { + // TODO: we can remove liteswap check when we can start serverless GPU clusters via API. + if !opts.IsServerlessMode() && opts.Liteswap == "" { err := checkClusterState(ctx, client, opts.ClusterID, opts.AutoStartCluster) if err != nil { return err @@ -195,7 +198,7 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt // sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). // For dedicated clusters, clusterID should be the same as sessionID. // For serverless, clusterID is read from the workspace metadata. -func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, sessionID, clusterID, version string) (int, string, string, error) { +func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, sessionID, clusterID, version, liteswap string) (int, string, string, error) { wsMetadata, err := sshWorkspace.GetWorkspaceMetadata(ctx, client, version, sessionID) if err != nil { return 0, "", "", errors.Join(errServerMetadata, err) @@ -222,6 +225,9 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, if err != nil { return 0, "", "", err } + if liteswap != "" { + req.Header.Set("x-databricks-traffic-id", "testenv://liteswap/"+liteswap) + } if err := client.Config.Authenticate(req); err != nil { return 0, "", "", err } @@ -357,7 +363,7 @@ func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, server func runSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, serverPort int, clusterID string, opts ClientOptions) error { createConn := func(ctx context.Context, connID string) (*websocket.Conn, error) { - return createWebsocketConnection(ctx, client, connID, clusterID, serverPort) + return createWebsocketConnection(ctx, client, connID, clusterID, serverPort, opts.Liteswap) } requestHandoverTick := func() <-chan time.Time { return time.After(opts.HandoverTimeout) @@ -390,7 +396,7 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC // For dedicated clusters, use clusterID; for serverless, it will be read from metadata clusterID := opts.ClusterID - serverPort, userName, effectiveClusterID, err := getServerMetadata(ctx, client, sessionID, clusterID, version) + serverPort, userName, effectiveClusterID, err := getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap) if errors.Is(err, errServerMetadata) { cmdio.LogString(ctx, "SSH server is not running, starting it now...") @@ -406,7 +412,7 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC if ctx.Err() != nil { return "", 0, "", ctx.Err() } - serverPort, userName, effectiveClusterID, err = getServerMetadata(ctx, client, sessionID, clusterID, version) + serverPort, userName, effectiveClusterID, err = getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap) if err == nil { cmdio.LogString(ctx, "Health check successful, starting ssh WebSocket connection...") break diff --git a/experimental/ssh/internal/client/websockets.go b/experimental/ssh/internal/client/websockets.go index b1ab20889f..fba53c891e 100644 --- a/experimental/ssh/internal/client/websockets.go +++ b/experimental/ssh/internal/client/websockets.go @@ -9,7 +9,7 @@ import ( "github.com/gorilla/websocket" ) -func createWebsocketConnection(ctx context.Context, client *databricks.WorkspaceClient, connID, clusterID string, serverPort int) (*websocket.Conn, error) { +func createWebsocketConnection(ctx context.Context, client *databricks.WorkspaceClient, connID, clusterID string, serverPort int, liteswap string) (*websocket.Conn, error) { url, err := getProxyURL(ctx, client, connID, clusterID, serverPort) if err != nil { return nil, fmt.Errorf("failed to get proxy URL: %w", err) @@ -20,6 +20,9 @@ func createWebsocketConnection(ctx context.Context, client *databricks.Workspace return nil, fmt.Errorf("failed to create request: %w", err) } + if liteswap != "" { + req.Header.Set("x-databricks-traffic-id", "testenv://liteswap/"+liteswap) + } if err := client.Config.Authenticate(req); err != nil { return nil, fmt.Errorf("failed to authenticate: %w", err) } From 28c330ea79a8326eb4662bcaa7eda404668446e0 Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Tue, 23 Dec 2025 14:30:05 +0100 Subject: [PATCH 03/25] Add liteswap option to the ProxyCommand. --- experimental/ssh/internal/setup/setup.go | 8 ++++++-- experimental/ssh/internal/setup/setup_test.go | 6 +++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/experimental/ssh/internal/setup/setup.go b/experimental/ssh/internal/setup/setup.go index 726747e4fd..e0fda32ded 100644 --- a/experimental/ssh/internal/setup/setup.go +++ b/experimental/ssh/internal/setup/setup.go @@ -60,7 +60,7 @@ func resolveConfigPath(configPath string) (string, error) { // sessionID is the unique identifier (cluster ID for dedicated clusters, connection name for serverless). // clusterID is the actual cluster ID for Driver Proxy connections (same as sessionID for dedicated clusters, // but obtained from job metadata for serverless). -func GenerateProxyCommand(sessionID, clusterID string, serverlessMode, autoStartCluster bool, shutdownDelay time.Duration, profile, userName string, serverPort int, handoverTimeout time.Duration) (string, error) { +func GenerateProxyCommand(sessionID, clusterID string, serverlessMode, autoStartCluster bool, shutdownDelay time.Duration, profile, userName string, serverPort int, handoverTimeout time.Duration, liteswap string) (string, error) { executablePath, err := os.Executable() if err != nil { return "", fmt.Errorf("failed to get current executable path: %w", err) @@ -91,6 +91,10 @@ func GenerateProxyCommand(sessionID, clusterID string, serverlessMode, autoStart proxyCommand += " --profile=" + profile } + if liteswap != "" { + proxyCommand += " --liteswap=" + liteswap + } + return proxyCommand, nil } @@ -100,7 +104,7 @@ func generateHostConfig(opts SetupOptions) (string, error) { return "", fmt.Errorf("failed to get local keys folder: %w", err) } - proxyCommand, err := GenerateProxyCommand(opts.ClusterID, opts.ClusterID, false, opts.AutoStartCluster, opts.ShutdownDelay, opts.Profile, "", 0, 0) + proxyCommand, err := GenerateProxyCommand(opts.ClusterID, opts.ClusterID, false, opts.AutoStartCluster, opts.ShutdownDelay, opts.Profile, "", 0, 0, "") if err != nil { return "", fmt.Errorf("failed to generate ProxyCommand: %w", err) } diff --git a/experimental/ssh/internal/setup/setup_test.go b/experimental/ssh/internal/setup/setup_test.go index 27a0ced5bc..f2e1bf6c1b 100644 --- a/experimental/ssh/internal/setup/setup_test.go +++ b/experimental/ssh/internal/setup/setup_test.go @@ -56,7 +56,7 @@ func TestValidateClusterAccess_ClusterNotFound(t *testing.T) { } func TestGenerateProxyCommand(t *testing.T) { - cmd, err := GenerateProxyCommand("cluster-123", "cluster-123", false, true, 45*time.Second, "", "", 0, 0) + cmd, err := GenerateProxyCommand("cluster-123", "cluster-123", false, true, 45*time.Second, "", "", 0, 0, "") assert.NoError(t, err) assert.Contains(t, cmd, "ssh connect --proxy --cluster=cluster-123 --auto-start-cluster=true --shutdown-delay=45s") assert.NotContains(t, cmd, "--metadata") @@ -65,7 +65,7 @@ func TestGenerateProxyCommand(t *testing.T) { } func TestGenerateProxyCommand_WithExtraArgs(t *testing.T) { - cmd, err := GenerateProxyCommand("cluster-123", "cluster-123", false, true, 45*time.Second, "test-profile", "user", 2222, 2*time.Minute) + cmd, err := GenerateProxyCommand("cluster-123", "cluster-123", false, true, 45*time.Second, "test-profile", "user", 2222, 2*time.Minute, "") assert.NoError(t, err) assert.Contains(t, cmd, "ssh connect --proxy --cluster=cluster-123 --auto-start-cluster=true --shutdown-delay=45s") assert.Contains(t, cmd, " --metadata=user,2222") @@ -74,7 +74,7 @@ func TestGenerateProxyCommand_WithExtraArgs(t *testing.T) { } func TestGenerateProxyCommand_ServerlessMode(t *testing.T) { - cmd, err := GenerateProxyCommand("my-connection", "serverless-cluster-id", true, false, 45*time.Second, "", "user", 2222, 0) + cmd, err := GenerateProxyCommand("my-connection", "serverless-cluster-id", true, false, 45*time.Second, "", "user", 2222, 0, "") assert.NoError(t, err) assert.Contains(t, cmd, "ssh connect --proxy --name=my-connection --shutdown-delay=45s") assert.Contains(t, cmd, " --metadata=user,2222,serverless-cluster-id") From aa44ec93dcf510557d63776c91bc7cdd2acab66d Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Fri, 9 Jan 2026 13:23:02 +0100 Subject: [PATCH 04/25] Fix lint error --- experimental/ssh/internal/client/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 7c7869559c..2996ea794d 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -329,7 +329,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, } func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, serverPort int, clusterID string, opts ClientOptions) error { - proxyCommand, err := setup.GenerateProxyCommand(opts.SessionIdentifier(), clusterID, opts.IsServerlessMode(), opts.AutoStartCluster, opts.ShutdownDelay, opts.Profile, userName, serverPort, opts.HandoverTimeout) + proxyCommand, err := setup.GenerateProxyCommand(opts.SessionIdentifier(), clusterID, opts.IsServerlessMode(), opts.AutoStartCluster, opts.ShutdownDelay, opts.Profile, userName, serverPort, opts.HandoverTimeout, opts.Liteswap) if err != nil { return fmt.Errorf("failed to generate ProxyCommand: %w", err) } From 9151da1f1d6e4e0d16e164b7732f57d03275c35c Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Mon, 19 Jan 2026 13:10:08 +0100 Subject: [PATCH 05/25] Add GPU accelerator support for serverless compute --- experimental/ssh/cmd/connect.go | 14 +- experimental/ssh/cmd/constants.go | 1 + experimental/ssh/internal/client/client.go | 293 ++++++++++++++++-- experimental/ssh/internal/setup/setup.go | 52 +--- experimental/ssh/internal/setup/setup_test.go | 41 ++- 5 files changed, 333 insertions(+), 68 deletions(-) diff --git a/experimental/ssh/cmd/connect.go b/experimental/ssh/cmd/connect.go index bc2f620e28..6c04db57e2 100644 --- a/experimental/ssh/cmd/connect.go +++ b/experimental/ssh/cmd/connect.go @@ -30,7 +30,7 @@ For serverless compute: var clusterID string var connectionName string - // var accelerator string + var accelerator string var proxyMode bool var serverMetadata string var shutdownDelay time.Duration @@ -43,6 +43,7 @@ For serverless compute: cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)") cmd.Flags().StringVar(&connectionName, "name", "", "Connection name (for serverless compute)") + cmd.Flags().StringVar(&accelerator, "accelerator", "", "GPU accelerator type for serverless compute (GPU_1xA10 or GPU_8xH100)") cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects") cmd.Flags().IntVar(&maxClients, "max-clients", defaultMaxClients, "Maximum number of SSH clients") cmd.Flags().BoolVar(&autoStartCluster, "auto-start-cluster", true, "Automatically start the cluster if it is not running") @@ -82,12 +83,22 @@ For serverless compute: return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the serverless connection name") } + if accelerator != "" && connectionName == "" { + return errors.New("--accelerator flag can only be used with serverless compute (--name flag)") + } + + // Remove when we add support for serverless CPU + if connectionName != "" && accelerator == "" { + return errors.New("--name flag requires --accelerator to be set (e.g. for now we only support serverless GPU compute)") + } + // TODO: validate connectionName if provided opts := client.ClientOptions{ Profile: wsClient.Config.Profile, ClusterID: clusterID, ConnectionName: connectionName, + Accelerator: accelerator, ProxyMode: proxyMode, ServerMetadata: serverMetadata, ShutdownDelay: shutdownDelay, @@ -95,6 +106,7 @@ For serverless compute: HandoverTimeout: handoverTimeout, ReleasesDir: releasesDir, ServerTimeout: max(serverTimeout, shutdownDelay), + TaskStartupTimeout: taskStartupTimeout, AutoStartCluster: autoStartCluster, ClientPublicKeyName: clientPublicKeyName, ClientPrivateKeyName: clientPrivateKeyName, diff --git a/experimental/ssh/cmd/constants.go b/experimental/ssh/cmd/constants.go index e812726845..db9f5ccc70 100644 --- a/experimental/ssh/cmd/constants.go +++ b/experimental/ssh/cmd/constants.go @@ -9,6 +9,7 @@ const ( defaultHandoverTimeout = 30 * time.Minute serverTimeout = 24 * time.Hour + taskStartupTimeout = 10 * time.Minute serverPortRange = 100 serverConfigDir = ".ssh-tunnel" serverPrivateKeyName = "server-private-key" diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 2996ea794d..bf2cdaa852 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -1,9 +1,11 @@ package client import ( + "bytes" "context" _ "embed" "encoding/base64" + "encoding/json" "errors" "fmt" "io" @@ -19,7 +21,6 @@ import ( "github.com/databricks/cli/experimental/ssh/internal/keys" "github.com/databricks/cli/experimental/ssh/internal/proxy" - "github.com/databricks/cli/experimental/ssh/internal/setup" sshWorkspace "github.com/databricks/cli/experimental/ssh/internal/workspace" "github.com/databricks/cli/internal/build" "github.com/databricks/cli/libs/cmdio" @@ -35,6 +36,8 @@ var sshServerBootstrapScript string var errServerMetadata = errors.New("server metadata error") +const sshServerTaskKey = "start_ssh_server" + type ClientOptions struct { // Id of the cluster to connect to (for dedicated clusters) ClusterID string @@ -57,6 +60,8 @@ type ClientOptions struct { HandoverTimeout time.Duration // Max amount of time the server process is allowed to live ServerTimeout time.Duration + // Max amount of time to wait for the SSH server task to reach RUNNING state + TaskStartupTimeout time.Duration // Directory for local SSH tunnel development releases. // If not present, the CLI will use github releases with the current version. ReleasesDir string @@ -91,6 +96,58 @@ func (o *ClientOptions) SessionIdentifier() string { return o.ClusterID } +// FormatMetadata formats the server metadata string for use in ProxyCommand. +// Returns empty string if userName is empty or serverPort is zero. +func FormatMetadata(userName string, serverPort int, clusterID string) string { + if userName == "" || serverPort == 0 { + return "" + } + if clusterID != "" { + return fmt.Sprintf("%s,%d,%s", userName, serverPort, clusterID) + } + return fmt.Sprintf("%s,%d", userName, serverPort) +} + +// ToProxyCommand generates the ProxyCommand string for SSH config. +// This method serializes the ClientOptions into a command-line invocation that will +// be parsed back into ClientOptions when the SSH ProxyCommand is executed. +func (o *ClientOptions) ToProxyCommand() (string, error) { + executablePath, err := os.Executable() + if err != nil { + return "", fmt.Errorf("failed to get current executable path: %w", err) + } + + var proxyCommand string + if o.IsServerlessMode() { + proxyCommand = fmt.Sprintf("%q ssh connect --proxy --name=%s --shutdown-delay=%s", + executablePath, o.ConnectionName, o.ShutdownDelay.String()) + if o.Accelerator != "" { + proxyCommand += " --accelerator=" + o.Accelerator + } + } else { + proxyCommand = fmt.Sprintf("%q ssh connect --proxy --cluster=%s --auto-start-cluster=%t --shutdown-delay=%s", + executablePath, o.ClusterID, o.AutoStartCluster, o.ShutdownDelay.String()) + } + + if o.ServerMetadata != "" { + proxyCommand += " --metadata=" + o.ServerMetadata + } + + if o.HandoverTimeout > 0 { + proxyCommand += " --handover-timeout=" + o.HandoverTimeout.String() + } + + if o.Profile != "" { + proxyCommand += " --profile=" + o.Profile + } + + if o.Liteswap != "" { + proxyCommand += " --liteswap=" + o.Liteswap + } + + return proxyCommand, nil +} + func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOptions) error { ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -242,6 +299,7 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, return 0, "", "", err } cmdio.LogString(ctx, "Metadata response: "+string(bodyBytes)) + cmdio.LogString(ctx, "Metadata response status code: "+strconv.Itoa(resp.StatusCode)) if resp.StatusCode != http.StatusOK { return 0, "", "", errors.Join(errServerMetadata, fmt.Errorf("server is not ok, status code %d", resp.StatusCode)) @@ -250,16 +308,16 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, return wsMetadata.Port, string(bodyBytes), effectiveClusterID, nil } -func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) (int64, error) { +func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) error { sessionID := opts.SessionIdentifier() contentDir, err := sshWorkspace.GetWorkspaceContentDir(ctx, client, version, sessionID) if err != nil { - return 0, fmt.Errorf("failed to get workspace content directory: %w", err) + return fmt.Errorf("failed to get workspace content directory: %w", err) } err = client.Workspace.MkdirsByPath(ctx, contentDir) if err != nil { - return 0, fmt.Errorf("failed to create directory in the remote workspace: %w", err) + return fmt.Errorf("failed to create directory in the remote workspace: %w", err) } sshTunnelJobName := "ssh-server-bootstrap-" + sessionID @@ -275,7 +333,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, Overwrite: true, }) if err != nil { - return 0, fmt.Errorf("failed to create ssh-tunnel notebook: %w", err) + return fmt.Errorf("failed to create ssh-tunnel notebook: %w", err) } baseParams := map[string]string{ @@ -287,6 +345,13 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, "sessionId": sessionID, } + cmdio.LogString(ctx, "Submitting a job to start the ssh server...") + + // Use manual HTTP call when hardware_accelerator is needed (SDK doesn't support it yet) + if opts.Accelerator != "" { + return submitSSHTunnelJobManual(ctx, client, jobNotebookPath, baseParams, opts) + } + task := jobs.SubmitTask{ TaskKey: "start_ssh_server", NotebookTask: &jobs.NotebookTask{ @@ -298,38 +363,160 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, if opts.IsServerlessMode() { task.EnvironmentKey = "ssh-tunnel-serverless" - // TODO: Add GPU accelerator configuration when Jobs API supports it } else { task.ExistingClusterId = opts.ClusterID } - submitRun := jobs.SubmitRun{ + submitRequest := jobs.SubmitRun{ RunName: sshTunnelJobName, TimeoutSeconds: int(opts.ServerTimeout.Seconds()), Tasks: []jobs.SubmitTask{task}, } if opts.IsServerlessMode() { - env := jobs.JobEnvironment{ - EnvironmentKey: "ssh-tunnel-serverless", - Spec: &compute.Environment{ - EnvironmentVersion: "3", + submitRequest.Environments = []jobs.JobEnvironment{ + { + EnvironmentKey: "ssh-tunnel-serverless", + Spec: &compute.Environment{ + EnvironmentVersion: "3", + }, }, } - submitRun.Environments = []jobs.JobEnvironment{env} } - cmdio.LogString(ctx, "Submitting a job to start the ssh server...") - runResult, err := client.Jobs.Submit(ctx, submitRun) + waiter, err := client.Jobs.Submit(ctx, submitRequest) + if err != nil { + return fmt.Errorf("failed to submit job: %w", err) + } + + cmdio.LogString(ctx, fmt.Sprintf("Job submitted successfully with run ID: %d", waiter.RunId)) + cmdio.LogString(ctx, "Waiting for the SSH server task to start...") + var prevState jobs.RunLifeCycleState + + _, err = waiter.OnProgress(func(run *jobs.Run) { + var sshTask *jobs.RunTask + for i := range run.Tasks { + if run.Tasks[i].TaskKey == sshServerTaskKey { + sshTask = &run.Tasks[i] + break + } + } + + if sshTask == nil || sshTask.State == nil { + return + } + + currentState := sshTask.State.LifeCycleState + + if currentState != prevState { + cmdio.LogString(ctx, fmt.Sprintf("Task status: %s", currentState)) + prevState = currentState + } + + if currentState == jobs.RunLifeCycleStateRunning { + cmdio.LogString(ctx, "SSH server task is now running, proceeding to connect...") + } + }).GetWithTimeout(opts.TaskStartupTimeout) + return err +} + +// submitSSHTunnelJobManual submits a job using manual HTTP call for features not yet supported by the SDK. +// Currently used for hardware_accelerator field which is not yet in the SDK. +func submitSSHTunnelJobManual(ctx context.Context, client *databricks.WorkspaceClient, jobNotebookPath string, baseParams map[string]string, opts ClientOptions) error { + sessionID := opts.SessionIdentifier() + sshTunnelJobName := "ssh-server-bootstrap-" + sessionID + + // Construct the request payload manually to allow custom parameters + task := map[string]any{ + "task_key": sshServerTaskKey, + "notebook_task": map[string]any{ + "notebook_path": jobNotebookPath, + "base_parameters": baseParams, + }, + "timeout_seconds": int(opts.ServerTimeout.Seconds()), + } + + if opts.IsServerlessMode() { + task["environment_key"] = "ssh-tunnel-serverless" + if opts.Accelerator != "" { + cmdio.LogString(ctx, "Using accelerator: "+opts.Accelerator) + task["compute"] = map[string]any{ + "hardware_accelerator": opts.Accelerator, + } + } + } else { + task["existing_cluster_id"] = opts.ClusterID + } + + submitRequest := map[string]any{ + "run_name": sshTunnelJobName, + "timeout_seconds": int(opts.ServerTimeout.Seconds()), + "tasks": []map[string]any{task}, + } + + if opts.IsServerlessMode() { + submitRequest["environments"] = []map[string]any{ + { + "environment_key": "ssh-tunnel-serverless", + "spec": map[string]any{ + "environment_version": "3", + }, + }, + } + } + + requestBody, err := json.Marshal(submitRequest) + if err != nil { + return fmt.Errorf("failed to marshal request body: %w", err) + } + + cmdio.LogString(ctx, "Request body: "+string(requestBody)) + + apiURL := client.Config.Host + "/api/2.1/jobs/runs/submit" + req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(requestBody)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + if err := client.Config.Authenticate(req); err != nil { + return fmt.Errorf("failed to authenticate request: %w", err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return fmt.Errorf("failed to submit job: %w", err) + } + defer resp.Body.Close() + + responseBody, err := io.ReadAll(resp.Body) if err != nil { - return 0, fmt.Errorf("failed to submit job: %w", err) + return fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("failed to submit job, status code %d: %s", resp.StatusCode, string(responseBody)) + } + + var result struct { + RunID int64 `json:"run_id"` + } + if err := json.Unmarshal(responseBody, &result); err != nil { + return fmt.Errorf("failed to parse response: %w", err) } - return runResult.Response.RunId, nil + cmdio.LogString(ctx, fmt.Sprintf("Job submitted successfully with run ID: %d", result.RunID)) + + // For manual submissions we still need to poll manually + return waitForJobToStart(ctx, client, result.RunID, opts.TaskStartupTimeout) } func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, serverPort int, clusterID string, opts ClientOptions) error { - proxyCommand, err := setup.GenerateProxyCommand(opts.SessionIdentifier(), clusterID, opts.IsServerlessMode(), opts.AutoStartCluster, opts.ShutdownDelay, opts.Profile, userName, serverPort, opts.HandoverTimeout, opts.Liteswap) + // Create a copy with metadata for the ProxyCommand + optsWithMetadata := opts + optsWithMetadata.ServerMetadata = FormatMetadata(userName, serverPort, clusterID) + + proxyCommand, err := optsWithMetadata.ToProxyCommand() if err != nil { return fmt.Errorf("failed to generate ProxyCommand: %w", err) } @@ -391,6 +578,73 @@ func checkClusterState(ctx context.Context, client *databricks.WorkspaceClient, return nil } +// waitForJobToStart polls the task status until the SSH server task is in RUNNING state or terminates. +// Returns an error if the task fails to start or if polling times out. +func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient, runID int64, taskStartupTimeout time.Duration) error { + cmdio.LogString(ctx, "Waiting for the SSH server task to start...") + const pollInterval = 2 * time.Second + maxRetries := int(taskStartupTimeout / pollInterval) + var prevState jobs.RunLifecycleStateV2State + + for retries := range maxRetries { + if ctx.Err() != nil { + return ctx.Err() + } + + run, err := client.Jobs.GetRun(ctx, jobs.GetRunRequest{ + RunId: runID, + }) + if err != nil { + return fmt.Errorf("failed to get job run status: %w", err) + } + + // Find the SSH server task + var sshTask *jobs.RunTask + for i := range run.Tasks { + if run.Tasks[i].TaskKey == sshServerTaskKey { + sshTask = &run.Tasks[i] + break + } + } + + if sshTask == nil { + return fmt.Errorf("SSH server task '%s' not found in job run", sshServerTaskKey) + } + + if sshTask.Status == nil { + return fmt.Errorf("task status is nil") + } + + currentState := sshTask.Status.State + + // Print status if it changed + if currentState != prevState { + cmdio.LogString(ctx, fmt.Sprintf("Task status: %s", currentState)) + prevState = currentState + } + + // Check if task is running + if currentState == jobs.RunLifecycleStateV2StateRunning { + cmdio.LogString(ctx, "SSH server task is now running, proceeding to connect...") + return nil + } + + // Check for terminal failure states + if currentState == jobs.RunLifecycleStateV2StateTerminated { + return fmt.Errorf("task terminated before reaching running state") + } + + // Continue polling + if retries < maxRetries-1 { + time.Sleep(pollInterval) + } else { + return fmt.Errorf("timeout waiting for task to start (state: %s)", currentState) + } + } + + return fmt.Errorf("timeout waiting for task to start") +} + func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) (string, int, string, error) { sessionID := opts.SessionIdentifier() // For dedicated clusters, use clusterID; for serverless, it will be read from metadata @@ -400,11 +654,10 @@ func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceC if errors.Is(err, errServerMetadata) { cmdio.LogString(ctx, "SSH server is not running, starting it now...") - runID, err := submitSSHTunnelJob(ctx, client, version, secretScopeName, opts) + err := submitSSHTunnelJob(ctx, client, version, secretScopeName, opts) if err != nil { - return "", 0, "", fmt.Errorf("failed to submit ssh server job: %w", err) + return "", 0, "", fmt.Errorf("failed to submit and start ssh server job: %w", err) } - cmdio.LogString(ctx, fmt.Sprintf("Job submitted successfully with run ID: %d", runID)) cmdio.LogString(ctx, "Waiting for the ssh server to start...") maxRetries := 30 diff --git a/experimental/ssh/internal/setup/setup.go b/experimental/ssh/internal/setup/setup.go index e0fda32ded..adfe204427 100644 --- a/experimental/ssh/internal/setup/setup.go +++ b/experimental/ssh/internal/setup/setup.go @@ -7,10 +7,10 @@ import ( "os" "path/filepath" "regexp" - "strconv" "strings" "time" + "github.com/databricks/cli/experimental/ssh/internal/client" "github.com/databricks/cli/experimental/ssh/internal/keys" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/databricks-sdk-go" @@ -56,55 +56,19 @@ func resolveConfigPath(configPath string) (string, error) { return filepath.Join(homeDir, ".ssh", "config"), nil } -// GenerateProxyCommand generates the ProxyCommand string for SSH config. -// sessionID is the unique identifier (cluster ID for dedicated clusters, connection name for serverless). -// clusterID is the actual cluster ID for Driver Proxy connections (same as sessionID for dedicated clusters, -// but obtained from job metadata for serverless). -func GenerateProxyCommand(sessionID, clusterID string, serverlessMode, autoStartCluster bool, shutdownDelay time.Duration, profile, userName string, serverPort int, handoverTimeout time.Duration, liteswap string) (string, error) { - executablePath, err := os.Executable() - if err != nil { - return "", fmt.Errorf("failed to get current executable path: %w", err) - } - - var proxyCommand string - if serverlessMode { - proxyCommand = fmt.Sprintf("%q ssh connect --proxy --name=%s --shutdown-delay=%s", - executablePath, sessionID, shutdownDelay.String()) - } else { - proxyCommand = fmt.Sprintf("%q ssh connect --proxy --cluster=%s --auto-start-cluster=%t --shutdown-delay=%s", - executablePath, clusterID, autoStartCluster, shutdownDelay.String()) - } - - if userName != "" && serverPort != 0 { - if serverlessMode && clusterID != "" { - proxyCommand += " --metadata=" + userName + "," + strconv.Itoa(serverPort) + "," + clusterID - } else { - proxyCommand += " --metadata=" + userName + "," + strconv.Itoa(serverPort) - } - } - - if handoverTimeout > 0 { - proxyCommand += " --handover-timeout=" + handoverTimeout.String() - } - - if profile != "" { - proxyCommand += " --profile=" + profile - } - - if liteswap != "" { - proxyCommand += " --liteswap=" + liteswap - } - - return proxyCommand, nil -} - func generateHostConfig(opts SetupOptions) (string, error) { identityFilePath, err := keys.GetLocalSSHKeyPath(opts.ClusterID, opts.SSHKeysDir) if err != nil { return "", fmt.Errorf("failed to get local keys folder: %w", err) } - proxyCommand, err := GenerateProxyCommand(opts.ClusterID, opts.ClusterID, false, opts.AutoStartCluster, opts.ShutdownDelay, opts.Profile, "", 0, 0, "") + clientOpts := client.ClientOptions{ + ClusterID: opts.ClusterID, + AutoStartCluster: opts.AutoStartCluster, + ShutdownDelay: opts.ShutdownDelay, + Profile: opts.Profile, + } + proxyCommand, err := clientOpts.ToProxyCommand() if err != nil { return "", fmt.Errorf("failed to generate ProxyCommand: %w", err) } diff --git a/experimental/ssh/internal/setup/setup_test.go b/experimental/ssh/internal/setup/setup_test.go index f2e1bf6c1b..aa803dfe1c 100644 --- a/experimental/ssh/internal/setup/setup_test.go +++ b/experimental/ssh/internal/setup/setup_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/databricks/cli/experimental/ssh/internal/client" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/databricks-sdk-go/experimental/mocks" "github.com/databricks/databricks-sdk-go/service/compute" @@ -56,7 +57,12 @@ func TestValidateClusterAccess_ClusterNotFound(t *testing.T) { } func TestGenerateProxyCommand(t *testing.T) { - cmd, err := GenerateProxyCommand("cluster-123", "cluster-123", false, true, 45*time.Second, "", "", 0, 0, "") + opts := client.ClientOptions{ + ClusterID: "cluster-123", + AutoStartCluster: true, + ShutdownDelay: 45 * time.Second, + } + cmd, err := opts.ToProxyCommand() assert.NoError(t, err) assert.Contains(t, cmd, "ssh connect --proxy --cluster=cluster-123 --auto-start-cluster=true --shutdown-delay=45s") assert.NotContains(t, cmd, "--metadata") @@ -65,7 +71,15 @@ func TestGenerateProxyCommand(t *testing.T) { } func TestGenerateProxyCommand_WithExtraArgs(t *testing.T) { - cmd, err := GenerateProxyCommand("cluster-123", "cluster-123", false, true, 45*time.Second, "test-profile", "user", 2222, 2*time.Minute, "") + opts := client.ClientOptions{ + ClusterID: "cluster-123", + AutoStartCluster: true, + ShutdownDelay: 45 * time.Second, + Profile: "test-profile", + ServerMetadata: "user,2222", + HandoverTimeout: 2 * time.Minute, + } + cmd, err := opts.ToProxyCommand() assert.NoError(t, err) assert.Contains(t, cmd, "ssh connect --proxy --cluster=cluster-123 --auto-start-cluster=true --shutdown-delay=45s") assert.Contains(t, cmd, " --metadata=user,2222") @@ -74,9 +88,30 @@ func TestGenerateProxyCommand_WithExtraArgs(t *testing.T) { } func TestGenerateProxyCommand_ServerlessMode(t *testing.T) { - cmd, err := GenerateProxyCommand("my-connection", "serverless-cluster-id", true, false, 45*time.Second, "", "user", 2222, 0, "") + opts := client.ClientOptions{ + ConnectionName: "my-connection", + ShutdownDelay: 45 * time.Second, + ServerMetadata: "user,2222,serverless-cluster-id", + } + cmd, err := opts.ToProxyCommand() + assert.NoError(t, err) + assert.Contains(t, cmd, "ssh connect --proxy --name=my-connection --shutdown-delay=45s") + assert.Contains(t, cmd, " --metadata=user,2222,serverless-cluster-id") + assert.NotContains(t, cmd, "--cluster=") + assert.NotContains(t, cmd, "--auto-start-cluster") +} + +func TestGenerateProxyCommand_ServerlessModeWithAccelerator(t *testing.T) { + opts := client.ClientOptions{ + ConnectionName: "my-connection", + ShutdownDelay: 45 * time.Second, + Accelerator: "GPU_1xA10", + ServerMetadata: "user,2222,serverless-cluster-id", + } + cmd, err := opts.ToProxyCommand() assert.NoError(t, err) assert.Contains(t, cmd, "ssh connect --proxy --name=my-connection --shutdown-delay=45s") + assert.Contains(t, cmd, " --accelerator=GPU_1xA10") assert.Contains(t, cmd, " --metadata=user,2222,serverless-cluster-id") assert.NotContains(t, cmd, "--cluster=") assert.NotContains(t, cmd, "--auto-start-cluster") From f488a38b622a60d0f708f526d46b8fd08a7950da Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Mon, 19 Jan 2026 13:15:44 +0100 Subject: [PATCH 06/25] Don't log metadata response to stdout --- experimental/ssh/internal/client/client.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index bf2cdaa852..0563b95ac7 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -24,6 +24,7 @@ import ( sshWorkspace "github.com/databricks/cli/experimental/ssh/internal/workspace" "github.com/databricks/cli/internal/build" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/log" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/service/compute" "github.com/databricks/databricks-sdk-go/service/jobs" @@ -298,8 +299,8 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, if err != nil { return 0, "", "", err } - cmdio.LogString(ctx, "Metadata response: "+string(bodyBytes)) - cmdio.LogString(ctx, "Metadata response status code: "+strconv.Itoa(resp.StatusCode)) + log.Debugf(ctx, "Metadata response: %s", string(bodyBytes)) + log.Debugf(ctx, "Metadata response status code: %d", resp.StatusCode) if resp.StatusCode != http.StatusOK { return 0, "", "", errors.Join(errServerMetadata, fmt.Errorf("server is not ok, status code %d", resp.StatusCode)) From 3bf460d1f3950d3be700307cc11d0423ac41aa14 Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Mon, 19 Jan 2026 13:17:34 +0100 Subject: [PATCH 07/25] Move metadata URL logging to debug level --- experimental/ssh/internal/client/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 0563b95ac7..26bdeb0eed 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -278,7 +278,7 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, return 0, "", "", err } metadataURL := fmt.Sprintf("%s/driver-proxy-api/o/%d/%s/%d/metadata", client.Config.Host, workspaceID, effectiveClusterID, wsMetadata.Port) - cmdio.LogString(ctx, "Metadata URL: "+metadataURL) + log.Debugf(ctx, "Metadata URL: %s", metadataURL) req, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil) if err != nil { return 0, "", "", err From 9961ebf2c55e2b0114e37923a75805b63efa4132 Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Tue, 27 Jan 2026 10:32:30 +0100 Subject: [PATCH 08/25] Simplify SSH tunnel job polling --- experimental/ssh/internal/client/client.go | 59 +++++----------------- 1 file changed, 12 insertions(+), 47 deletions(-) diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 26bdeb0eed..c05df4a791 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -26,6 +26,7 @@ import ( "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/log" "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/retries" "github.com/databricks/databricks-sdk-go/service/compute" "github.com/databricks/databricks-sdk-go/service/jobs" "github.com/databricks/databricks-sdk-go/service/workspace" @@ -391,34 +392,8 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, } cmdio.LogString(ctx, fmt.Sprintf("Job submitted successfully with run ID: %d", waiter.RunId)) - cmdio.LogString(ctx, "Waiting for the SSH server task to start...") - var prevState jobs.RunLifeCycleState - - _, err = waiter.OnProgress(func(run *jobs.Run) { - var sshTask *jobs.RunTask - for i := range run.Tasks { - if run.Tasks[i].TaskKey == sshServerTaskKey { - sshTask = &run.Tasks[i] - break - } - } - - if sshTask == nil || sshTask.State == nil { - return - } - currentState := sshTask.State.LifeCycleState - - if currentState != prevState { - cmdio.LogString(ctx, fmt.Sprintf("Task status: %s", currentState)) - prevState = currentState - } - - if currentState == jobs.RunLifeCycleStateRunning { - cmdio.LogString(ctx, "SSH server task is now running, proceeding to connect...") - } - }).GetWithTimeout(opts.TaskStartupTimeout) - return err + return waitForJobToStart(ctx, client, waiter.RunId, opts.TaskStartupTimeout) } // submitSSHTunnelJobManual submits a job using manual HTTP call for features not yet supported by the SDK. @@ -583,20 +558,14 @@ func checkClusterState(ctx context.Context, client *databricks.WorkspaceClient, // Returns an error if the task fails to start or if polling times out. func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient, runID int64, taskStartupTimeout time.Duration) error { cmdio.LogString(ctx, "Waiting for the SSH server task to start...") - const pollInterval = 2 * time.Second - maxRetries := int(taskStartupTimeout / pollInterval) var prevState jobs.RunLifecycleStateV2State - for retries := range maxRetries { - if ctx.Err() != nil { - return ctx.Err() - } - + _, err := retries.Poll(ctx, taskStartupTimeout, func() (*jobs.RunTask, *retries.Err) { run, err := client.Jobs.GetRun(ctx, jobs.GetRunRequest{ RunId: runID, }) if err != nil { - return fmt.Errorf("failed to get job run status: %w", err) + return nil, retries.Halt(fmt.Errorf("failed to get job run status: %w", err)) } // Find the SSH server task @@ -609,11 +578,11 @@ func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient, } if sshTask == nil { - return fmt.Errorf("SSH server task '%s' not found in job run", sshServerTaskKey) + return nil, retries.Halt(fmt.Errorf("SSH server task '%s' not found in job run", sshServerTaskKey)) } if sshTask.Status == nil { - return fmt.Errorf("task status is nil") + return nil, retries.Halt(errors.New("task status is nil")) } currentState := sshTask.Status.State @@ -627,23 +596,19 @@ func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient, // Check if task is running if currentState == jobs.RunLifecycleStateV2StateRunning { cmdio.LogString(ctx, "SSH server task is now running, proceeding to connect...") - return nil + return sshTask, nil } // Check for terminal failure states if currentState == jobs.RunLifecycleStateV2StateTerminated { - return fmt.Errorf("task terminated before reaching running state") + return nil, retries.Halt(errors.New("task terminated before reaching running state")) } - // Continue polling - if retries < maxRetries-1 { - time.Sleep(pollInterval) - } else { - return fmt.Errorf("timeout waiting for task to start (state: %s)", currentState) - } - } + // Continue polling for other states + return nil, retries.Continues(fmt.Sprintf("waiting for task to start (current state: %s)", currentState)) + }) - return fmt.Errorf("timeout waiting for task to start") + return err } func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) (string, int, string, error) { From 6a6f4f828b304956a2558f4b89c96bfbe7857ba0 Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Fri, 6 Feb 2026 14:20:07 +0100 Subject: [PATCH 09/25] Make sure that connection name is not empty in serverless mode --- experimental/ssh/internal/client/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index c05df4a791..b0f9744f0d 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -86,7 +86,7 @@ type ClientOptions struct { } func (o *ClientOptions) IsServerlessMode() bool { - return o.ClusterID == "" + return o.ClusterID == "" && o.ConnectionName != "" } // SessionIdentifier returns the unique identifier for the session. From 7040a1f568c1fac2eb21b040a2c195dbe48c4d85 Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Fri, 6 Feb 2026 14:24:42 +0100 Subject: [PATCH 10/25] Extract serverless environment key to a constant --- experimental/ssh/internal/client/client.go | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index b0f9744f0d..839705c4ec 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -38,7 +38,10 @@ var sshServerBootstrapScript string var errServerMetadata = errors.New("server metadata error") -const sshServerTaskKey = "start_ssh_server" +const ( + sshServerTaskKey = "start_ssh_server" + serverlessEnvironmentKey = "ssh_tunnel_serverless" +) type ClientOptions struct { // Id of the cluster to connect to (for dedicated clusters) @@ -355,7 +358,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, } task := jobs.SubmitTask{ - TaskKey: "start_ssh_server", + TaskKey: sshServerTaskKey, NotebookTask: &jobs.NotebookTask{ NotebookPath: jobNotebookPath, BaseParameters: baseParams, @@ -364,7 +367,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, } if opts.IsServerlessMode() { - task.EnvironmentKey = "ssh-tunnel-serverless" + task.EnvironmentKey = serverlessEnvironmentKey } else { task.ExistingClusterId = opts.ClusterID } @@ -378,7 +381,7 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, if opts.IsServerlessMode() { submitRequest.Environments = []jobs.JobEnvironment{ { - EnvironmentKey: "ssh-tunnel-serverless", + EnvironmentKey: serverlessEnvironmentKey, Spec: &compute.Environment{ EnvironmentVersion: "3", }, @@ -413,7 +416,7 @@ func submitSSHTunnelJobManual(ctx context.Context, client *databricks.WorkspaceC } if opts.IsServerlessMode() { - task["environment_key"] = "ssh-tunnel-serverless" + task["environment_key"] = serverlessEnvironmentKey if opts.Accelerator != "" { cmdio.LogString(ctx, "Using accelerator: "+opts.Accelerator) task["compute"] = map[string]any{ @@ -433,7 +436,7 @@ func submitSSHTunnelJobManual(ctx context.Context, client *databricks.WorkspaceC if opts.IsServerlessMode() { submitRequest["environments"] = []map[string]any{ { - "environment_key": "ssh-tunnel-serverless", + "environment_key": serverlessEnvironmentKey, "spec": map[string]any{ "environment_version": "3", }, From a4d274e60fc0b9ebb2e20f3a646fc9f5d13e276d Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Mon, 19 Jan 2026 15:23:10 +0100 Subject: [PATCH 11/25] Add IDE flag and support for IDE integration for vscode and cursor --- experimental/ssh/cmd/connect.go | 11 +- experimental/ssh/cmd/setup.go | 21 +++- experimental/ssh/internal/client/client.go | 138 ++++++++++++++++++++- experimental/ssh/internal/setup/setup.go | 16 +-- 4 files changed, 163 insertions(+), 23 deletions(-) diff --git a/experimental/ssh/cmd/connect.go b/experimental/ssh/cmd/connect.go index 6c04db57e2..2d9102dbbd 100644 --- a/experimental/ssh/cmd/connect.go +++ b/experimental/ssh/cmd/connect.go @@ -32,6 +32,7 @@ For serverless compute: var connectionName string var accelerator string var proxyMode bool + var ide string var serverMetadata string var shutdownDelay time.Duration var maxClients int @@ -42,8 +43,9 @@ For serverless compute: var liteswap string cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)") - cmd.Flags().StringVar(&connectionName, "name", "", "Connection name (for serverless compute)") - cmd.Flags().StringVar(&accelerator, "accelerator", "", "GPU accelerator type for serverless compute (GPU_1xA10 or GPU_8xH100)") + cmd.Flags().StringVar(&connectionName, "name", "", "Connection name") + cmd.Flags().StringVar(&accelerator, "accelerator", "", "GPU accelerator type (GPU_1xA10 or GPU_8xH100)") + cmd.Flags().StringVar(&ide, "ide", "", "Open remote IDE window (vscode or cursor)") cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects") cmd.Flags().IntVar(&maxClients, "max-clients", defaultMaxClients, "Maximum number of SSH clients") cmd.Flags().BoolVar(&autoStartCluster, "auto-start-cluster", true, "Automatically start the cluster if it is not running") @@ -80,7 +82,7 @@ For serverless compute: wsClient := cmdctx.WorkspaceClient(ctx) if !proxyMode && clusterID == "" && connectionName == "" { - return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the serverless connection name") + return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)") } if accelerator != "" && connectionName == "" { @@ -89,7 +91,7 @@ For serverless compute: // Remove when we add support for serverless CPU if connectionName != "" && accelerator == "" { - return errors.New("--name flag requires --accelerator to be set (e.g. for now we only support serverless GPU compute)") + return errors.New("--name flag requires --accelerator to be set (for now we only support serverless GPU compute)") } // TODO: validate connectionName if provided @@ -100,6 +102,7 @@ For serverless compute: ConnectionName: connectionName, Accelerator: accelerator, ProxyMode: proxyMode, + IDE: ide, ServerMetadata: serverMetadata, ShutdownDelay: shutdownDelay, MaxClients: maxClients, diff --git a/experimental/ssh/cmd/setup.go b/experimental/ssh/cmd/setup.go index 3e4523904c..81b7863666 100644 --- a/experimental/ssh/cmd/setup.go +++ b/experimental/ssh/cmd/setup.go @@ -1,9 +1,11 @@ package ssh import ( + "fmt" "time" "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/experimental/ssh/internal/client" "github.com/databricks/cli/experimental/ssh/internal/setup" "github.com/databricks/cli/libs/cmdctx" "github.com/spf13/cobra" @@ -43,16 +45,27 @@ an SSH host configuration to your SSH config file. cmd.RunE = func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - client := cmdctx.WorkspaceClient(ctx) - opts := setup.SetupOptions{ + wsClient := cmdctx.WorkspaceClient(ctx) + setupOpts := setup.SetupOptions{ HostName: hostName, ClusterID: clusterID, AutoStartCluster: autoStartCluster, SSHConfigPath: sshConfigPath, ShutdownDelay: shutdownDelay, - Profile: client.Config.Profile, + Profile: wsClient.Config.Profile, } - return setup.Setup(ctx, client, opts) + clientOpts := client.ClientOptions{ + ClusterID: setupOpts.ClusterID, + AutoStartCluster: setupOpts.AutoStartCluster, + ShutdownDelay: setupOpts.ShutdownDelay, + Profile: setupOpts.Profile, + } + proxyCommand, err := clientOpts.ToProxyCommand() + if err != nil { + return fmt.Errorf("failed to generate ProxyCommand: %w", err) + } + setupOpts.ProxyCommand = proxyCommand + return setup.Setup(ctx, wsClient, setupOpts) } return cmd diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 839705c4ec..23572bf323 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -14,6 +14,7 @@ import ( "os/exec" "os/signal" "path/filepath" + "regexp" "strconv" "strings" "syscall" @@ -58,6 +59,8 @@ type ClientOptions struct { // to the cluster and proxy all traffic through stdin/stdout. // In the non proxy mode the CLI spawns an ssh client with the ProxyCommand config. ProxyMode bool + // Open remote IDE window with a specific ssh config (empty, 'vscode', or 'cursor') + IDE string // Expected format: ",,". // If present, the CLI won't attempt to start the server. ServerMetadata string @@ -171,8 +174,7 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt } // Only check cluster state for dedicated clusters - // TODO: we can remove liteswap check when we can start serverless GPU clusters via API. - if !opts.IsServerlessMode() && opts.Liteswap == "" { + if !opts.IsServerlessMode() { err := checkClusterState(ctx, client, opts.ClusterID, opts.AutoStartCluster) if err != nil { return err @@ -250,12 +252,144 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt if opts.ProxyMode { return runSSHProxy(ctx, client, serverPort, clusterID, opts) + } else if opts.IDE != "" { + return runIDE(ctx, client, userName, keyPath, serverPort, clusterID, opts) } else { cmdio.LogString(ctx, fmt.Sprintf("Additional SSH arguments: %v", opts.AdditionalArgs)) return spawnSSHClient(ctx, userName, keyPath, serverPort, clusterID, opts) } } +func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error { + // Validate IDE value + if opts.IDE != "vscode" && opts.IDE != "cursor" { + return fmt.Errorf("invalid IDE value: %s, expected 'vscode' or 'cursor'", opts.IDE) + } + + // Get connection name + connectionName := opts.SessionIdentifier() + if connectionName == "" { + return errors.New("connection name is required for IDE integration") + } + + // Get Databricks user name for the workspace path + currentUser, err := client.CurrentUser.Me(ctx) + if err != nil { + return fmt.Errorf("failed to get current user: %w", err) + } + databricksUserName := currentUser.UserName + + // Ensure SSH config entry exists + configPath, err := getSSHConfigPath() + if err != nil { + return fmt.Errorf("failed to get SSH config path: %w", err) + } + + err = ensureSSHConfigEntry(ctx, configPath, connectionName, userName, keyPath, serverPort, clusterID, opts) + if err != nil { + return fmt.Errorf("failed to ensure SSH config entry: %w", err) + } + + // Determine the IDE command + ideCommand := "code" + if opts.IDE == "cursor" { + ideCommand = "cursor" + } + + // Construct the remote SSH URI + // Format: ssh-remote+@ /Workspace/Users// + remoteURI := fmt.Sprintf("ssh-remote+%s@%s", userName, connectionName) + remotePath := fmt.Sprintf("/Workspace/Users/%s/", databricksUserName) + + cmdio.LogString(ctx, fmt.Sprintf("Launching %s with remote URI: %s and path: %s", opts.IDE, remoteURI, remotePath)) + + // Launch the IDE + ideCmd := exec.CommandContext(ctx, ideCommand, "--remote", remoteURI, remotePath) + ideCmd.Stdout = os.Stdout + ideCmd.Stderr = os.Stderr + + return ideCmd.Run() +} + +func getSSHConfigPath() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get home directory: %w", err) + } + return filepath.Join(homeDir, ".ssh", "config"), nil +} + +func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error { + // Ensure SSH directory and config file exist + sshDir := filepath.Dir(configPath) + err := os.MkdirAll(sshDir, 0o700) + if err != nil { + return fmt.Errorf("failed to create SSH directory: %w", err) + } + + _, err = os.Stat(configPath) + if os.IsNotExist(err) { + err = os.WriteFile(configPath, []byte(""), 0o600) + if err != nil { + return fmt.Errorf("failed to create SSH config file: %w", err) + } + } else if err != nil { + return fmt.Errorf("failed to check SSH config file: %w", err) + } + + // Check if the host entry already exists + existingContent, err := os.ReadFile(configPath) + if err != nil { + return fmt.Errorf("failed to read SSH config file: %w", err) + } + + hostPattern := fmt.Sprintf(`(?m)^\s*Host\s+%s\s*$`, regexp.QuoteMeta(hostName)) + matched, err := regexp.Match(hostPattern, existingContent) + if err != nil { + return fmt.Errorf("failed to check for existing host: %w", err) + } + + if matched { + cmdio.LogString(ctx, fmt.Sprintf("SSH config entry for '%s' already exists", hostName)) + return nil + } + + // Generate ProxyCommand with server metadata + optsWithMetadata := opts + optsWithMetadata.ServerMetadata = FormatMetadata(userName, serverPort, clusterID) + + proxyCommand, err := optsWithMetadata.ToProxyCommand() + if err != nil { + return fmt.Errorf("failed to generate ProxyCommand: %w", err) + } + + // Generate host config + hostConfig := fmt.Sprintf(` +Host %s + User %s + ConnectTimeout 360 + StrictHostKeyChecking accept-new + IdentitiesOnly yes + IdentityFile %q + ProxyCommand %s +`, hostName, userName, keyPath, proxyCommand) + + // Append to config file + content := string(existingContent) + if !strings.HasSuffix(content, "\n") && content != "" { + content += "\n" + } + content += hostConfig + + err = os.WriteFile(configPath, []byte(content), 0o600) + if err != nil { + return fmt.Errorf("failed to update SSH config file: %w", err) + } + + cmdio.LogString(ctx, fmt.Sprintf("Added SSH config entry for '%s'", hostName)) + return nil +} + // getServerMetadata retrieves the server metadata from the workspace and validates it via Driver Proxy. // sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). // For dedicated clusters, clusterID should be the same as sessionID. diff --git a/experimental/ssh/internal/setup/setup.go b/experimental/ssh/internal/setup/setup.go index adfe204427..0d76071a65 100644 --- a/experimental/ssh/internal/setup/setup.go +++ b/experimental/ssh/internal/setup/setup.go @@ -10,7 +10,6 @@ import ( "strings" "time" - "github.com/databricks/cli/experimental/ssh/internal/client" "github.com/databricks/cli/experimental/ssh/internal/keys" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/databricks-sdk-go" @@ -32,6 +31,8 @@ type SetupOptions struct { SSHKeysDir string // Optional auth profile name. If present, will be added as --profile flag to the ProxyCommand Profile string + // Proxy command to use for the SSH connection + ProxyCommand string } func validateClusterAccess(ctx context.Context, client *databricks.WorkspaceClient, clusterID string) error { @@ -62,17 +63,6 @@ func generateHostConfig(opts SetupOptions) (string, error) { return "", fmt.Errorf("failed to get local keys folder: %w", err) } - clientOpts := client.ClientOptions{ - ClusterID: opts.ClusterID, - AutoStartCluster: opts.AutoStartCluster, - ShutdownDelay: opts.ShutdownDelay, - Profile: opts.Profile, - } - proxyCommand, err := clientOpts.ToProxyCommand() - if err != nil { - return "", fmt.Errorf("failed to generate ProxyCommand: %w", err) - } - hostConfig := fmt.Sprintf(` Host %s User root @@ -81,7 +71,7 @@ Host %s IdentitiesOnly yes IdentityFile %q ProxyCommand %s -`, opts.HostName, identityFilePath, proxyCommand) +`, opts.HostName, identityFilePath, opts.ProxyCommand) return hostConfig, nil } From a0009d8951b8bd5c70e00178f503a232363aeb76 Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Thu, 5 Feb 2026 17:16:33 +0100 Subject: [PATCH 12/25] Mark hidden flags for serverless compute, remove manual Http call --- experimental/ssh/cmd/connect.go | 10 +- experimental/ssh/internal/client/client.go | 104 ++------------------- 2 files changed, 13 insertions(+), 101 deletions(-) diff --git a/experimental/ssh/cmd/connect.go b/experimental/ssh/cmd/connect.go index 2d9102dbbd..87876d85ef 100644 --- a/experimental/ssh/cmd/connect.go +++ b/experimental/ssh/cmd/connect.go @@ -43,13 +43,17 @@ For serverless compute: var liteswap string cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)") - cmd.Flags().StringVar(&connectionName, "name", "", "Connection name") - cmd.Flags().StringVar(&accelerator, "accelerator", "", "GPU accelerator type (GPU_1xA10 or GPU_8xH100)") - cmd.Flags().StringVar(&ide, "ide", "", "Open remote IDE window (vscode or cursor)") cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects") cmd.Flags().IntVar(&maxClients, "max-clients", defaultMaxClients, "Maximum number of SSH clients") cmd.Flags().BoolVar(&autoStartCluster, "auto-start-cluster", true, "Automatically start the cluster if it is not running") + cmd.Flags().StringVar(&connectionName, "name", "", "Connection name (for serverless compute)") + cmd.Flags().MarkHidden("name") + cmd.Flags().StringVar(&accelerator, "accelerator", "", "GPU accelerator type (GPU_1xA10 or GPU_8xH100)") + cmd.Flags().MarkHidden("accelerator") + cmd.Flags().StringVar(&ide, "ide", "", "Open remote IDE window (vscode or cursor)") + cmd.Flags().MarkHidden("ide") + cmd.Flags().BoolVar(&proxyMode, "proxy", false, "ProxyCommand mode") cmd.Flags().MarkHidden("proxy") cmd.Flags().StringVar(&serverMetadata, "metadata", "", "Metadata of the running SSH server (format: ,)") diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 23572bf323..2fcc08f373 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -1,11 +1,9 @@ package client import ( - "bytes" "context" _ "embed" "encoding/base64" - "encoding/json" "errors" "fmt" "io" @@ -486,11 +484,6 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, cmdio.LogString(ctx, "Submitting a job to start the ssh server...") - // Use manual HTTP call when hardware_accelerator is needed (SDK doesn't support it yet) - if opts.Accelerator != "" { - return submitSSHTunnelJobManual(ctx, client, jobNotebookPath, baseParams, opts) - } - task := jobs.SubmitTask{ TaskKey: sshServerTaskKey, NotebookTask: &jobs.NotebookTask{ @@ -502,6 +495,12 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, if opts.IsServerlessMode() { task.EnvironmentKey = serverlessEnvironmentKey + if opts.Accelerator != "" { + cmdio.LogString(ctx, "Using accelerator: "+opts.Accelerator) + task.Compute = &jobs.Compute{ + HardwareAccelerator: compute.HardwareAcceleratorType(opts.Accelerator), + } + } } else { task.ExistingClusterId = opts.ClusterID } @@ -533,97 +532,6 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, return waitForJobToStart(ctx, client, waiter.RunId, opts.TaskStartupTimeout) } -// submitSSHTunnelJobManual submits a job using manual HTTP call for features not yet supported by the SDK. -// Currently used for hardware_accelerator field which is not yet in the SDK. -func submitSSHTunnelJobManual(ctx context.Context, client *databricks.WorkspaceClient, jobNotebookPath string, baseParams map[string]string, opts ClientOptions) error { - sessionID := opts.SessionIdentifier() - sshTunnelJobName := "ssh-server-bootstrap-" + sessionID - - // Construct the request payload manually to allow custom parameters - task := map[string]any{ - "task_key": sshServerTaskKey, - "notebook_task": map[string]any{ - "notebook_path": jobNotebookPath, - "base_parameters": baseParams, - }, - "timeout_seconds": int(opts.ServerTimeout.Seconds()), - } - - if opts.IsServerlessMode() { - task["environment_key"] = serverlessEnvironmentKey - if opts.Accelerator != "" { - cmdio.LogString(ctx, "Using accelerator: "+opts.Accelerator) - task["compute"] = map[string]any{ - "hardware_accelerator": opts.Accelerator, - } - } - } else { - task["existing_cluster_id"] = opts.ClusterID - } - - submitRequest := map[string]any{ - "run_name": sshTunnelJobName, - "timeout_seconds": int(opts.ServerTimeout.Seconds()), - "tasks": []map[string]any{task}, - } - - if opts.IsServerlessMode() { - submitRequest["environments"] = []map[string]any{ - { - "environment_key": serverlessEnvironmentKey, - "spec": map[string]any{ - "environment_version": "3", - }, - }, - } - } - - requestBody, err := json.Marshal(submitRequest) - if err != nil { - return fmt.Errorf("failed to marshal request body: %w", err) - } - - cmdio.LogString(ctx, "Request body: "+string(requestBody)) - - apiURL := client.Config.Host + "/api/2.1/jobs/runs/submit" - req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewReader(requestBody)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - if err := client.Config.Authenticate(req); err != nil { - return fmt.Errorf("failed to authenticate request: %w", err) - } - - resp, err := http.DefaultClient.Do(req) - if err != nil { - return fmt.Errorf("failed to submit job: %w", err) - } - defer resp.Body.Close() - - responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read response body: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("failed to submit job, status code %d: %s", resp.StatusCode, string(responseBody)) - } - - var result struct { - RunID int64 `json:"run_id"` - } - if err := json.Unmarshal(responseBody, &result); err != nil { - return fmt.Errorf("failed to parse response: %w", err) - } - - cmdio.LogString(ctx, fmt.Sprintf("Job submitted successfully with run ID: %d", result.RunID)) - - // For manual submissions we still need to poll manually - return waitForJobToStart(ctx, client, result.RunID, opts.TaskStartupTimeout) -} - func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, serverPort int, clusterID string, opts ClientOptions) error { // Create a copy with metadata for the ProxyCommand optsWithMetadata := opts From 51b8b2686ed9e38823a69a56170db3e07e6a6636 Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Fri, 6 Feb 2026 14:38:58 +0100 Subject: [PATCH 13/25] Separate IDE options into constants --- experimental/ssh/internal/client/client.go | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 2fcc08f373..6c4695b6ac 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -40,6 +40,11 @@ var errServerMetadata = errors.New("server metadata error") const ( sshServerTaskKey = "start_ssh_server" serverlessEnvironmentKey = "ssh_tunnel_serverless" + + VSCodeOption = "vscode" + VSCodeCommand = "code" + CursorOption = "cursor" + CursorCommand = "cursor" ) type ClientOptions struct { @@ -259,12 +264,10 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt } func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error { - // Validate IDE value - if opts.IDE != "vscode" && opts.IDE != "cursor" { - return fmt.Errorf("invalid IDE value: %s, expected 'vscode' or 'cursor'", opts.IDE) + if opts.IDE != VSCodeOption && opts.IDE != CursorOption { + return fmt.Errorf("invalid IDE value: %s, expected '%s' or '%s'", opts.IDE, VSCodeOption, CursorOption) } - // Get connection name connectionName := opts.SessionIdentifier() if connectionName == "" { return errors.New("connection name is required for IDE integration") @@ -288,10 +291,9 @@ func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, k return fmt.Errorf("failed to ensure SSH config entry: %w", err) } - // Determine the IDE command - ideCommand := "code" - if opts.IDE == "cursor" { - ideCommand = "cursor" + ideCommand := VSCodeCommand + if opts.IDE == CursorOption { + ideCommand = CursorCommand } // Construct the remote SSH URI @@ -301,7 +303,6 @@ func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, k cmdio.LogString(ctx, fmt.Sprintf("Launching %s with remote URI: %s and path: %s", opts.IDE, remoteURI, remotePath)) - // Launch the IDE ideCmd := exec.CommandContext(ctx, ideCommand, "--remote", remoteURI, remotePath) ideCmd.Stdout = os.Stdout ideCmd.Stderr = os.Stderr From abb3157927b16003b6a5458ebc8c568f6419c0aa Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Tue, 10 Feb 2026 13:15:48 +0100 Subject: [PATCH 14/25] Move unnecessary logging to debug level --- experimental/ssh/internal/client/client.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 6c4695b6ac..6b3e6b6382 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -398,7 +398,7 @@ func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, if err != nil { return 0, "", "", errors.Join(errServerMetadata, err) } - cmdio.LogString(ctx, "Workspace metadata: "+fmt.Sprintf("%+v", wsMetadata)) + log.Debugf(ctx, "Workspace metadata: %+v", wsMetadata) // For serverless mode, the cluster ID comes from the metadata effectiveClusterID := clusterID @@ -559,8 +559,7 @@ func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, server sshArgs = append(sshArgs, hostName) sshArgs = append(sshArgs, opts.AdditionalArgs...) - cmdio.LogString(ctx, "Launching SSH client: ssh "+strings.Join(sshArgs, " ")) - + log.Debugf(ctx, "Launching SSH client: ssh %s", strings.Join(sshArgs, " ")) sshCmd := exec.CommandContext(ctx, "ssh", sshArgs...) sshCmd.Stdin = os.Stdin From 627868cd771f031e491bce90f6166ecfeb94d187 Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Thu, 12 Feb 2026 15:17:00 +0100 Subject: [PATCH 15/25] Remove dedicated cluster and serverless compute specific documentation from connect command help message --- experimental/ssh/cmd/connect.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/experimental/ssh/cmd/connect.go b/experimental/ssh/cmd/connect.go index 87876d85ef..4eca1aee7b 100644 --- a/experimental/ssh/cmd/connect.go +++ b/experimental/ssh/cmd/connect.go @@ -19,12 +19,6 @@ func newConnectCommand() *cobra.Command { This command establishes an SSH connection to Databricks compute, setting up the SSH server and handling the connection proxy. -For dedicated clusters: - databricks ssh connect --cluster= - -For serverless compute: - databricks ssh connect --name= [--accelerator=] - ` + disclaimer, } From a6a0cf9a7e7e6caa8983731e47678c996ebf4cfd Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Fri, 23 Jan 2026 10:45:15 +0100 Subject: [PATCH 16/25] New config logic --- experimental/ssh/internal/client/client.go | 65 +--- experimental/ssh/internal/setup/setup.go | 110 ++----- experimental/ssh/internal/setup/setup_test.go | 301 +++++------------- .../ssh/internal/sshconfig/sshconfig.go | 176 ++++++++++ .../ssh/internal/sshconfig/sshconfig_test.go | 212 ++++++++++++ 5 files changed, 497 insertions(+), 367 deletions(-) create mode 100644 experimental/ssh/internal/sshconfig/sshconfig.go create mode 100644 experimental/ssh/internal/sshconfig/sshconfig_test.go diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 6b3e6b6382..e2beb2fc51 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -12,7 +12,6 @@ import ( "os/exec" "os/signal" "path/filepath" - "regexp" "strconv" "strings" "syscall" @@ -20,6 +19,7 @@ import ( "github.com/databricks/cli/experimental/ssh/internal/keys" "github.com/databricks/cli/experimental/ssh/internal/proxy" + "github.com/databricks/cli/experimental/ssh/internal/sshconfig" sshWorkspace "github.com/databricks/cli/experimental/ssh/internal/workspace" "github.com/databricks/cli/internal/build" "github.com/databricks/cli/libs/cmdio" @@ -281,7 +281,7 @@ func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, k databricksUserName := currentUser.UserName // Ensure SSH config entry exists - configPath, err := getSSHConfigPath() + configPath, err := sshconfig.GetMainConfigPath() if err != nil { return fmt.Errorf("failed to get SSH config path: %w", err) } @@ -310,47 +310,11 @@ func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, k return ideCmd.Run() } -func getSSHConfigPath() (string, error) { - homeDir, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("failed to get home directory: %w", err) - } - return filepath.Join(homeDir, ".ssh", "config"), nil -} - func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error { - // Ensure SSH directory and config file exist - sshDir := filepath.Dir(configPath) - err := os.MkdirAll(sshDir, 0o700) - if err != nil { - return fmt.Errorf("failed to create SSH directory: %w", err) - } - - _, err = os.Stat(configPath) - if os.IsNotExist(err) { - err = os.WriteFile(configPath, []byte(""), 0o600) - if err != nil { - return fmt.Errorf("failed to create SSH config file: %w", err) - } - } else if err != nil { - return fmt.Errorf("failed to check SSH config file: %w", err) - } - - // Check if the host entry already exists - existingContent, err := os.ReadFile(configPath) - if err != nil { - return fmt.Errorf("failed to read SSH config file: %w", err) - } - - hostPattern := fmt.Sprintf(`(?m)^\s*Host\s+%s\s*$`, regexp.QuoteMeta(hostName)) - matched, err := regexp.Match(hostPattern, existingContent) + // Ensure the Include directive exists in the main SSH config + err := sshconfig.EnsureIncludeDirective(configPath) if err != nil { - return fmt.Errorf("failed to check for existing host: %w", err) - } - - if matched { - cmdio.LogString(ctx, fmt.Sprintf("SSH config entry for '%s' already exists", hostName)) - return nil + return err } // Generate ProxyCommand with server metadata @@ -373,16 +337,21 @@ Host %s ProxyCommand %s `, hostName, userName, keyPath, proxyCommand) - // Append to config file - content := string(existingContent) - if !strings.HasSuffix(content, "\n") && content != "" { - content += "\n" + // Check if the host config already exists + exists, err := sshconfig.HostConfigExists(hostName) + if err != nil { + return err + } + + if exists { + cmdio.LogString(ctx, fmt.Sprintf("SSH config entry for '%s' already exists, skipping", hostName)) + return nil } - content += hostConfig - err = os.WriteFile(configPath, []byte(content), 0o600) + // Create the host config file + _, err = sshconfig.CreateOrUpdateHostConfig(ctx, hostName, hostConfig, false) if err != nil { - return fmt.Errorf("failed to update SSH config file: %w", err) + return err } cmdio.LogString(ctx, fmt.Sprintf("Added SSH config entry for '%s'", hostName)) diff --git a/experimental/ssh/internal/setup/setup.go b/experimental/ssh/internal/setup/setup.go index 0d76071a65..3cd5d41182 100644 --- a/experimental/ssh/internal/setup/setup.go +++ b/experimental/ssh/internal/setup/setup.go @@ -4,13 +4,10 @@ import ( "context" "errors" "fmt" - "os" - "path/filepath" - "regexp" - "strings" "time" "github.com/databricks/cli/experimental/ssh/internal/keys" + "github.com/databricks/cli/experimental/ssh/internal/sshconfig" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/service/compute" @@ -46,17 +43,6 @@ func validateClusterAccess(ctx context.Context, client *databricks.WorkspaceClie return nil } -func resolveConfigPath(configPath string) (string, error) { - if configPath != "" { - return configPath, nil - } - homeDir, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("failed to get home directory: %w", err) - } - return filepath.Join(homeDir, ".ssh", "config"), nil -} - func generateHostConfig(opts SetupOptions) (string, error) { identityFilePath, err := keys.GetLocalSSHKeyPath(opts.ClusterID, opts.SSHKeysDir) if err != nil { @@ -76,67 +62,6 @@ Host %s return hostConfig, nil } -func ensureSSHConfigExists(configPath string) error { - _, err := os.Stat(configPath) - if os.IsNotExist(err) { - sshDir := filepath.Dir(configPath) - err = os.MkdirAll(sshDir, 0o700) - if err != nil { - return fmt.Errorf("failed to create SSH directory: %w", err) - } - err = os.WriteFile(configPath, []byte(""), 0o600) - if err != nil { - return fmt.Errorf("failed to create SSH config file: %w", err) - } - return nil - } else if err != nil { - return fmt.Errorf("failed to check SSH config file: %w", err) - } - return nil -} - -func checkExistingHosts(content []byte, hostName string) (bool, error) { - existingContent := string(content) - pattern := fmt.Sprintf(`(?m)^\s*Host\s+%s\s*$`, regexp.QuoteMeta(hostName)) - matched, err := regexp.MatchString(pattern, existingContent) - if err != nil { - return false, fmt.Errorf("failed to check for existing host: %w", err) - } - if matched { - return true, nil - } - return false, nil -} - -func createBackup(content []byte, configPath string) (string, error) { - backupPath := configPath + ".bak" - err := os.WriteFile(backupPath, content, 0o600) - if err != nil { - return backupPath, fmt.Errorf("failed to create backup of SSH config file: %w", err) - } - return backupPath, nil -} - -func updateSSHConfigFile(configPath, hostConfig, hostName string) error { - content, err := os.ReadFile(configPath) - if err != nil { - return fmt.Errorf("failed to read SSH config file: %w", err) - } - - existingContent := string(content) - if !strings.HasSuffix(existingContent, "\n") && existingContent != "" { - existingContent += "\n" - } - newContent := existingContent + hostConfig - - err = os.WriteFile(configPath, []byte(newContent), 0o600) - if err != nil { - return fmt.Errorf("failed to update SSH config file: %w", err) - } - - return nil -} - func clusterSelectionPrompt(ctx context.Context, client *databricks.WorkspaceClient) (string, error) { sp := cmdio.NewSpinner(ctx) sp.Update("Loading clusters.") @@ -174,50 +99,51 @@ func Setup(ctx context.Context, client *databricks.WorkspaceClient, opts SetupOp return err } - configPath, err := resolveConfigPath(opts.SSHConfigPath) + configPath, err := sshconfig.GetMainConfigPathOrDefault(opts.SSHConfigPath) if err != nil { return err } - hostConfig, err := generateHostConfig(opts) + err = sshconfig.EnsureIncludeDirective(configPath) if err != nil { return err } - err = ensureSSHConfigExists(configPath) + hostConfig, err := generateHostConfig(opts) if err != nil { return err } - existingContent, err := os.ReadFile(configPath) + exists, err := sshconfig.HostConfigExists(opts.HostName) if err != nil { - return fmt.Errorf("failed to read SSH config file: %w", err) + return err } - if len(existingContent) > 0 { - exists, err := checkExistingHosts(existingContent, opts.HostName) + recreate := false + if exists { + recreate, err = sshconfig.PromptRecreateConfig(ctx, opts.HostName) if err != nil { return err } - if exists { - cmdio.LogString(ctx, fmt.Sprintf("Host '%s' already exists in the SSH config, skipping setup", opts.HostName)) + if !recreate { + cmdio.LogString(ctx, fmt.Sprintf("Skipping setup for host '%s'", opts.HostName)) return nil } - backupPath, err := createBackup(existingContent, configPath) - if err != nil { - return err - } - cmdio.LogString(ctx, "Created backup of existing SSH config at "+backupPath) } cmdio.LogString(ctx, "Adding new entry to the SSH config:\n"+hostConfig) - err = updateSSHConfigFile(configPath, hostConfig, opts.HostName) + _, err = sshconfig.CreateOrUpdateHostConfig(ctx, opts.HostName, hostConfig, recreate) + if err != nil { + return err + } + + hostConfigPath, err := sshconfig.GetHostConfigPath(opts.HostName) if err != nil { return err } - cmdio.LogString(ctx, fmt.Sprintf("Updated SSH config file at %s with '%s' host", configPath, opts.HostName)) + cmdio.LogString(ctx, fmt.Sprintf("Created SSH config file at %s for '%s' host", hostConfigPath, opts.HostName)) cmdio.LogString(ctx, fmt.Sprintf("You can now connect to the cluster using 'ssh %s' terminal command, or use remote capabilities of your IDE", opts.HostName)) return nil } diff --git a/experimental/ssh/internal/setup/setup_test.go b/experimental/ssh/internal/setup/setup_test.go index aa803dfe1c..43e4f105b8 100644 --- a/experimental/ssh/internal/setup/setup_test.go +++ b/experimental/ssh/internal/setup/setup_test.go @@ -118,15 +118,24 @@ func TestGenerateProxyCommand_ServerlessModeWithAccelerator(t *testing.T) { } func TestGenerateHostConfig_Valid(t *testing.T) { - // Create a temporary directory for testing tmpDir := t.TempDir() + clientOpts := client.ClientOptions{ + ClusterID: "cluster-123", + AutoStartCluster: true, + ShutdownDelay: 30 * time.Second, + Profile: "test-profile", + } + proxyCommand, err := clientOpts.ToProxyCommand() + require.NoError(t, err) + opts := SetupOptions{ HostName: "test-host", ClusterID: "cluster-123", SSHKeysDir: tmpDir, ShutdownDelay: 30 * time.Second, Profile: "test-profile", + ProxyCommand: proxyCommand, } result, err := generateHostConfig(opts) @@ -139,29 +148,35 @@ func TestGenerateHostConfig_Valid(t *testing.T) { assert.Contains(t, result, "--shutdown-delay=30s") assert.Contains(t, result, "--profile=test-profile") - // Check that identity file path is included expectedKeyPath := filepath.Join(tmpDir, "cluster-123") assert.Contains(t, result, fmt.Sprintf(`IdentityFile %q`, expectedKeyPath)) } func TestGenerateHostConfig_WithoutProfile(t *testing.T) { - // Create a temporary directory for testing tmpDir := t.TempDir() + clientOpts := client.ClientOptions{ + ClusterID: "cluster-123", + AutoStartCluster: true, + ShutdownDelay: 30 * time.Second, + Profile: "", + } + proxyCommand, err := clientOpts.ToProxyCommand() + require.NoError(t, err) + opts := SetupOptions{ HostName: "test-host", ClusterID: "cluster-123", SSHKeysDir: tmpDir, ShutdownDelay: 30 * time.Second, - Profile: "", // No profile + Profile: "", + ProxyCommand: proxyCommand, } result, err := generateHostConfig(opts) assert.NoError(t, err) - // Should not contain profile option assert.NotContains(t, result, "--profile=") - // But should contain other elements assert.Contains(t, result, "Host test-host") assert.Contains(t, result, "--cluster=cluster-123") } @@ -187,181 +202,13 @@ func TestGenerateHostConfig_PathEscaping(t *testing.T) { assert.Contains(t, result, fmt.Sprintf(`IdentityFile %q`, expectedPath)) } -func TestEnsureSSHConfigExists(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, ".ssh", "config") - - err := ensureSSHConfigExists(configPath) - assert.NoError(t, err) - - // Check that directory was created - _, err = os.Stat(filepath.Dir(configPath)) - assert.NoError(t, err) - - // Check that file was created - _, err = os.Stat(configPath) - assert.NoError(t, err) - - // Check that file is empty - content, err := os.ReadFile(configPath) - assert.NoError(t, err) - assert.Empty(t, content) -} - -func TestCheckExistingHosts_NoExistingHost(t *testing.T) { - content := []byte(`Host other-host - User root - HostName example.com - -Host another-host - User admin -`) - exists, err := checkExistingHosts(content, "test-host") - assert.NoError(t, err) - assert.False(t, exists) -} - -func TestCheckExistingHosts_HostAlreadyExists(t *testing.T) { - content := []byte(`Host test-host - User root - HostName example.com - -Host another-host - User admin -`) - exists, err := checkExistingHosts(content, "another-host") - assert.NoError(t, err) - assert.True(t, exists) -} - -func TestCheckExistingHosts_EmptyContent(t *testing.T) { - content := []byte("") - exists, err := checkExistingHosts(content, "test-host") - assert.NoError(t, err) - assert.False(t, exists) -} - -func TestCheckExistingHosts_HostNameWithWhitespaces(t *testing.T) { - content := []byte(` Host test-host `) - exists, err := checkExistingHosts(content, "test-host") - assert.NoError(t, err) - assert.True(t, exists) -} - -func TestCheckExistingHosts_PartialNameMatch(t *testing.T) { - content := []byte(`Host test-host-long`) - exists, err := checkExistingHosts(content, "test-host") - assert.NoError(t, err) - assert.False(t, exists) -} - -func TestCreateBackup_CreatesBackupSuccessfully(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - content := []byte("original content") - - backupPath, err := createBackup(content, configPath) - assert.NoError(t, err) - assert.Equal(t, configPath+".bak", backupPath) - - // Check that backup file was created with correct content - backupContent, err := os.ReadFile(backupPath) - assert.NoError(t, err) - assert.Equal(t, content, backupContent) -} - -func TestCreateBackup_OverwritesExistingBackup(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - backupPath := configPath + ".bak" - - // Create existing backup - oldContent := []byte("old backup") - err := os.WriteFile(backupPath, oldContent, 0o644) - require.NoError(t, err) - - // Create new backup - newContent := []byte("new content") - resultPath, err := createBackup(newContent, configPath) - assert.NoError(t, err) - assert.Equal(t, backupPath, resultPath) - - // Check that backup was overwritten - backupContent, err := os.ReadFile(backupPath) - assert.NoError(t, err) - assert.Equal(t, newContent, backupContent) -} - -func TestUpdateSSHConfigFile_UpdatesSuccessfully(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - - // Create initial config file - initialContent := "# SSH Config\nHost existing\n User root\n" - err := os.WriteFile(configPath, []byte(initialContent), 0o600) - require.NoError(t, err) - - hostConfig := "\nHost new-host\n User root\n HostName example.com\n" - err = updateSSHConfigFile(configPath, hostConfig, "new-host") - assert.NoError(t, err) - - // Check that content was appended - finalContent, err := os.ReadFile(configPath) - assert.NoError(t, err) - expected := initialContent + hostConfig - assert.Equal(t, expected, string(finalContent)) -} - -func TestUpdateSSHConfigFile_AddsNewlineIfMissing(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - - // Create config file without trailing newline - initialContent := "Host existing\n User root" - err := os.WriteFile(configPath, []byte(initialContent), 0o600) - require.NoError(t, err) - - hostConfig := "\nHost new-host\n User root\n" - err = updateSSHConfigFile(configPath, hostConfig, "new-host") - assert.NoError(t, err) - - // Check that newline was added before the new content - finalContent, err := os.ReadFile(configPath) - assert.NoError(t, err) - expected := initialContent + "\n" + hostConfig - assert.Equal(t, expected, string(finalContent)) -} - -func TestUpdateSSHConfigFile_HandlesEmptyFile(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - - // Create empty config file - err := os.WriteFile(configPath, []byte(""), 0o600) - require.NoError(t, err) - - hostConfig := "Host new-host\n User root\n" - err = updateSSHConfigFile(configPath, hostConfig, "new-host") - assert.NoError(t, err) - - // Check that content was added without extra newlines - finalContent, err := os.ReadFile(configPath) - assert.NoError(t, err) - assert.Equal(t, hostConfig, string(finalContent)) -} - -func TestUpdateSSHConfigFile_HandlesReadError(t *testing.T) { - configPath := "/nonexistent/file" - hostConfig := "Host new-host\n" - - err := updateSSHConfigFile(configPath, hostConfig, "new-host") - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to read SSH config file") -} - func TestSetup_SuccessfulWithNewConfigFile(t *testing.T) { ctx := cmdio.MockDiscard(context.Background()) tmpDir := t.TempDir() + homeDir := os.Getenv("HOME") + defer func() { os.Setenv("HOME", homeDir) }() + os.Setenv("HOME", tmpDir) + configPath := filepath.Join(tmpDir, "ssh_config") m := mocks.NewMockWorkspaceClient(t) @@ -380,22 +227,43 @@ func TestSetup_SuccessfulWithNewConfigFile(t *testing.T) { Profile: "test-profile", } - err := Setup(ctx, m.WorkspaceClient, opts) + clientOpts := client.ClientOptions{ + ClusterID: opts.ClusterID, + AutoStartCluster: opts.AutoStartCluster, + ShutdownDelay: opts.ShutdownDelay, + Profile: opts.Profile, + } + proxyCommand, err := clientOpts.ToProxyCommand() + require.NoError(t, err) + opts.ProxyCommand = proxyCommand + + err = Setup(ctx, m.WorkspaceClient, opts) assert.NoError(t, err) - // Check that config file was created + // Check that main config has Include directive content, err := os.ReadFile(configPath) assert.NoError(t, err) - configStr := string(content) - assert.Contains(t, configStr, "Host test-host") - assert.Contains(t, configStr, "--cluster=cluster-123") - assert.Contains(t, configStr, "--profile=test-profile") + assert.Contains(t, configStr, "Include") + assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") + + // Check that host config file was created + hostConfigPath := filepath.Join(tmpDir, ".databricks/ssh-tunnel-configs/test-host") + hostContent, err := os.ReadFile(hostConfigPath) + assert.NoError(t, err) + hostConfigStr := string(hostContent) + assert.Contains(t, hostConfigStr, "Host test-host") + assert.Contains(t, hostConfigStr, "--cluster=cluster-123") + assert.Contains(t, hostConfigStr, "--profile=test-profile") } func TestSetup_SuccessfulWithExistingConfigFile(t *testing.T) { ctx := cmdio.MockDiscard(context.Background()) tmpDir := t.TempDir() + homeDir := os.Getenv("HOME") + defer func() { os.Setenv("HOME", homeDir) }() + os.Setenv("HOME", tmpDir) + configPath := filepath.Join(tmpDir, "ssh_config") // Create existing config file @@ -418,54 +286,33 @@ func TestSetup_SuccessfulWithExistingConfigFile(t *testing.T) { ShutdownDelay: 60 * time.Second, } - err = Setup(ctx, m.WorkspaceClient, opts) - assert.NoError(t, err) - - // Check that config file was updated and backup was created - content, err := os.ReadFile(configPath) - assert.NoError(t, err) - - configStr := string(content) - assert.Contains(t, configStr, "# Existing SSH Config") // Original content preserved - assert.Contains(t, configStr, "Host new-host") // New content added - assert.Contains(t, configStr, "--cluster=cluster-456") - - // Check backup was created - backupPath := configPath + ".bak" - backupContent, err := os.ReadFile(backupPath) - assert.NoError(t, err) - assert.Equal(t, existingContent, string(backupContent)) -} - -func TestSetup_DoesNotOverrideExistingHost(t *testing.T) { - ctx := cmdio.MockDiscard(context.Background()) - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "ssh_config") - - // Create config file with existing host - existingContent := "Host duplicate-host\n User root\n" - err := os.WriteFile(configPath, []byte(existingContent), 0o600) - require.NoError(t, err) - - m := mocks.NewMockWorkspaceClient(t) - clustersAPI := m.GetMockClustersAPI() - - clustersAPI.EXPECT().Get(ctx, compute.GetClusterRequest{ClusterId: "cluster-123"}).Return(&compute.ClusterDetails{ - DataSecurityMode: compute.DataSecurityModeSingleUser, - }, nil) - - opts := SetupOptions{ - HostName: "duplicate-host", // Same as existing - ClusterID: "cluster-123", - SSHConfigPath: configPath, - SSHKeysDir: tmpDir, - ShutdownDelay: 30 * time.Second, + clientOpts := client.ClientOptions{ + ClusterID: opts.ClusterID, + AutoStartCluster: opts.AutoStartCluster, + ShutdownDelay: opts.ShutdownDelay, + Profile: opts.Profile, } + proxyCommand, err := clientOpts.ToProxyCommand() + require.NoError(t, err) + opts.ProxyCommand = proxyCommand err = Setup(ctx, m.WorkspaceClient, opts) assert.NoError(t, err) + // Check that main config has Include directive and preserves existing content content, err := os.ReadFile(configPath) assert.NoError(t, err) - assert.Equal(t, "Host duplicate-host\n User root\n", string(content)) + configStr := string(content) + assert.Contains(t, configStr, "Include") + assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") + assert.Contains(t, configStr, "# Existing SSH Config") + assert.Contains(t, configStr, "Host existing-host") + + // Check that host config file was created + hostConfigPath := filepath.Join(tmpDir, ".databricks/ssh-tunnel-configs/new-host") + hostContent, err := os.ReadFile(hostConfigPath) + assert.NoError(t, err) + hostConfigStr := string(hostContent) + assert.Contains(t, hostConfigStr, "Host new-host") + assert.Contains(t, hostConfigStr, "--cluster=cluster-456") } diff --git a/experimental/ssh/internal/sshconfig/sshconfig.go b/experimental/ssh/internal/sshconfig/sshconfig.go new file mode 100644 index 0000000000..2b3e886c87 --- /dev/null +++ b/experimental/ssh/internal/sshconfig/sshconfig.go @@ -0,0 +1,176 @@ +package sshconfig + +import ( + "context" + "fmt" + "os" + "path/filepath" + "regexp" + "strings" + + "github.com/databricks/cli/libs/cmdio" +) + +const ( + // ConfigDirName is the directory name for Databricks SSH tunnel configs + ConfigDirName = ".databricks/ssh-tunnel-configs" +) + +// GetConfigDir returns the path to the Databricks SSH tunnel configs directory. +func GetConfigDir() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get home directory: %w", err) + } + return filepath.Join(homeDir, ConfigDirName), nil +} + +// GetMainConfigPath returns the path to the main SSH config file. +func GetMainConfigPath() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get home directory: %w", err) + } + return filepath.Join(homeDir, ".ssh", "config"), nil +} + +// GetMainConfigPathOrDefault returns the provided path if non-empty, otherwise returns the default. +func GetMainConfigPathOrDefault(configPath string) (string, error) { + if configPath != "" { + return configPath, nil + } + return GetMainConfigPath() +} + +// EnsureMainConfigExists ensures the main SSH config file exists. +func EnsureMainConfigExists(configPath string) error { + _, err := os.Stat(configPath) + if os.IsNotExist(err) { + sshDir := filepath.Dir(configPath) + err = os.MkdirAll(sshDir, 0o700) + if err != nil { + return fmt.Errorf("failed to create SSH directory: %w", err) + } + err = os.WriteFile(configPath, []byte(""), 0o600) + if err != nil { + return fmt.Errorf("failed to create SSH config file: %w", err) + } + return nil + } + return err +} + +// EnsureIncludeDirective ensures the Include directive for Databricks configs exists in the main SSH config. +func EnsureIncludeDirective(configPath string) error { + configDir, err := GetConfigDir() + if err != nil { + return err + } + + err = os.MkdirAll(configDir, 0o700) + if err != nil { + return fmt.Errorf("failed to create Databricks SSH config directory: %w", err) + } + + err = EnsureMainConfigExists(configPath) + if err != nil { + return err + } + + content, err := os.ReadFile(configPath) + if err != nil { + return fmt.Errorf("failed to read SSH config file: %w", err) + } + + // Check if Include directive already exists + includePattern := fmt.Sprintf(`(?m)^\s*Include\s+.*%s/\*\s*$`, regexp.QuoteMeta(ConfigDirName)) + matched, err := regexp.Match(includePattern, content) + if err != nil { + return fmt.Errorf("failed to check for existing Include directive: %w", err) + } + + if matched { + return nil + } + + // Prepend the Include directive + includeLine := fmt.Sprintf("Include %s/*\n", configDir) + newContent := includeLine + if len(content) > 0 && !strings.HasPrefix(string(content), "\n") { + newContent += "\n" + } + newContent += string(content) + + err = os.WriteFile(configPath, []byte(newContent), 0o600) + if err != nil { + return fmt.Errorf("failed to update SSH config file with Include directive: %w", err) + } + + return nil +} + +// GetHostConfigPath returns the path to a specific host's config file. +func GetHostConfigPath(hostName string) (string, error) { + configDir, err := GetConfigDir() + if err != nil { + return "", err + } + return filepath.Join(configDir, hostName), nil +} + +// HostConfigExists checks if a host config file already exists. +func HostConfigExists(hostName string) (bool, error) { + configPath, err := GetHostConfigPath(hostName) + if err != nil { + return false, err + } + _, err = os.Stat(configPath) + if os.IsNotExist(err) { + return false, nil + } + if err != nil { + return false, fmt.Errorf("failed to check host config file: %w", err) + } + return true, nil +} + +// CreateOrUpdateHostConfig creates or updates a host config file. +// Returns true if the config was created/updated, false if it was skipped. +func CreateOrUpdateHostConfig(ctx context.Context, hostName, hostConfig string, recreate bool) (bool, error) { + configPath, err := GetHostConfigPath(hostName) + if err != nil { + return false, err + } + + exists, err := HostConfigExists(hostName) + if err != nil { + return false, err + } + + if exists && !recreate { + return false, nil + } + + // Ensure the config directory exists + configDir := filepath.Dir(configPath) + err = os.MkdirAll(configDir, 0o700) + if err != nil { + return false, fmt.Errorf("failed to create config directory: %w", err) + } + + err = os.WriteFile(configPath, []byte(hostConfig), 0o600) + if err != nil { + return false, fmt.Errorf("failed to write host config file: %w", err) + } + + return true, nil +} + +// PromptRecreateConfig asks the user if they want to recreate an existing config. +func PromptRecreateConfig(ctx context.Context, hostName string) (bool, error) { + response, err := cmdio.AskYesOrNo(ctx, fmt.Sprintf("Host '%s' already exists. Do you want to recreate the config?", hostName)) + if err != nil { + return false, err + } + return response, nil +} diff --git a/experimental/ssh/internal/sshconfig/sshconfig_test.go b/experimental/ssh/internal/sshconfig/sshconfig_test.go new file mode 100644 index 0000000000..98b7cef886 --- /dev/null +++ b/experimental/ssh/internal/sshconfig/sshconfig_test.go @@ -0,0 +1,212 @@ +package sshconfig + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/databricks/cli/libs/cmdio" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetConfigDir(t *testing.T) { + dir, err := GetConfigDir() + assert.NoError(t, err) + assert.Contains(t, dir, ".databricks/ssh-tunnel-configs") +} + +func TestGetMainConfigPath(t *testing.T) { + path, err := GetMainConfigPath() + assert.NoError(t, err) + assert.Contains(t, path, ".ssh/config") +} + +func TestGetMainConfigPathOrDefault(t *testing.T) { + path, err := GetMainConfigPathOrDefault("/custom/path") + assert.NoError(t, err) + assert.Equal(t, "/custom/path", path) + + path, err = GetMainConfigPathOrDefault("") + assert.NoError(t, err) + assert.Contains(t, path, ".ssh/config") +} + +func TestEnsureMainConfigExists(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".ssh", "config") + + err := EnsureMainConfigExists(configPath) + assert.NoError(t, err) + + _, err = os.Stat(filepath.Dir(configPath)) + assert.NoError(t, err) + + _, err = os.Stat(configPath) + assert.NoError(t, err) + + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Empty(t, content) +} + +func TestEnsureIncludeDirective_NewConfig(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".ssh", "config") + + err := EnsureIncludeDirective(configPath) + assert.NoError(t, err) + + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + + configStr := string(content) + assert.Contains(t, configStr, "Include") + assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") +} + +func TestEnsureIncludeDirective_AlreadyExists(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".ssh", "config") + + configDir, err := GetConfigDir() + require.NoError(t, err) + + existingContent := "Include " + configDir + "/*\n\nHost example\n User test\n" + err = os.MkdirAll(filepath.Dir(configPath), 0o700) + require.NoError(t, err) + err = os.WriteFile(configPath, []byte(existingContent), 0o600) + require.NoError(t, err) + + err = EnsureIncludeDirective(configPath) + assert.NoError(t, err) + + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Equal(t, existingContent, string(content)) +} + +func TestEnsureIncludeDirective_PrependsToExisting(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".ssh", "config") + + existingContent := "Host example\n User test\n" + err := os.MkdirAll(filepath.Dir(configPath), 0o700) + require.NoError(t, err) + err = os.WriteFile(configPath, []byte(existingContent), 0o600) + require.NoError(t, err) + + err = EnsureIncludeDirective(configPath) + assert.NoError(t, err) + + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + + configStr := string(content) + assert.Contains(t, configStr, "Include") + assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") + assert.Contains(t, configStr, "Host example") + + includeIndex := len("Include") + hostIndex := len(configStr) - len(existingContent) + assert.Less(t, includeIndex, hostIndex, "Include directive should come before existing content") +} + +func TestGetHostConfigPath(t *testing.T) { + path, err := GetHostConfigPath("test-host") + assert.NoError(t, err) + assert.Contains(t, path, ".databricks/ssh-tunnel-configs/test-host") +} + +func TestHostConfigExists(t *testing.T) { + tmpDir := t.TempDir() + homeDir := os.Getenv("HOME") + defer func() { os.Setenv("HOME", homeDir) }() + os.Setenv("HOME", tmpDir) + + exists, err := HostConfigExists("nonexistent") + assert.NoError(t, err) + assert.False(t, exists) + + configDir := filepath.Join(tmpDir, ConfigDirName) + err = os.MkdirAll(configDir, 0o700) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(configDir, "existing-host"), []byte("config"), 0o600) + require.NoError(t, err) + + exists, err = HostConfigExists("existing-host") + assert.NoError(t, err) + assert.True(t, exists) +} + +func TestCreateOrUpdateHostConfig_NewConfig(t *testing.T) { + ctx := cmdio.MockDiscard(context.Background()) + tmpDir := t.TempDir() + homeDir := os.Getenv("HOME") + defer func() { os.Setenv("HOME", homeDir) }() + os.Setenv("HOME", tmpDir) + + hostConfig := "Host test\n User root\n" + created, err := CreateOrUpdateHostConfig(ctx, "test-host", hostConfig, false) + assert.NoError(t, err) + assert.True(t, created) + + configPath, err := GetHostConfigPath("test-host") + require.NoError(t, err) + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Equal(t, hostConfig, string(content)) +} + +func TestCreateOrUpdateHostConfig_ExistingConfigNoRecreate(t *testing.T) { + ctx := cmdio.MockDiscard(context.Background()) + tmpDir := t.TempDir() + homeDir := os.Getenv("HOME") + defer func() { os.Setenv("HOME", homeDir) }() + os.Setenv("HOME", tmpDir) + + configDir := filepath.Join(tmpDir, ConfigDirName) + err := os.MkdirAll(configDir, 0o700) + require.NoError(t, err) + existingConfig := "Host test\n User admin\n" + err = os.WriteFile(filepath.Join(configDir, "test-host"), []byte(existingConfig), 0o600) + require.NoError(t, err) + + newConfig := "Host test\n User root\n" + created, err := CreateOrUpdateHostConfig(ctx, "test-host", newConfig, false) + assert.NoError(t, err) + assert.False(t, created) + + configPath, err := GetHostConfigPath("test-host") + require.NoError(t, err) + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Equal(t, existingConfig, string(content)) +} + +func TestCreateOrUpdateHostConfig_ExistingConfigWithRecreate(t *testing.T) { + ctx := cmdio.MockDiscard(context.Background()) + tmpDir := t.TempDir() + homeDir := os.Getenv("HOME") + defer func() { os.Setenv("HOME", homeDir) }() + os.Setenv("HOME", tmpDir) + + configDir := filepath.Join(tmpDir, ConfigDirName) + err := os.MkdirAll(configDir, 0o700) + require.NoError(t, err) + existingConfig := "Host test\n User admin\n" + err = os.WriteFile(filepath.Join(configDir, "test-host"), []byte(existingConfig), 0o600) + require.NoError(t, err) + + newConfig := "Host test\n User root\n" + created, err := CreateOrUpdateHostConfig(ctx, "test-host", newConfig, true) + assert.NoError(t, err) + assert.True(t, created) + + configPath, err := GetHostConfigPath("test-host") + require.NoError(t, err) + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Equal(t, newConfig, string(content)) +} From f0b080c6045f48794160053091d7e6ae32a18e44 Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Thu, 29 Jan 2026 16:14:01 +0100 Subject: [PATCH 17/25] Always recreate SSH config entry --- experimental/ssh/internal/client/client.go | 17 ++--------------- .../ssh/internal/sshconfig/sshconfig.go | 1 - 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index e2beb2fc51..c3edb61e63 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -326,7 +326,6 @@ func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, k return fmt.Errorf("failed to generate ProxyCommand: %w", err) } - // Generate host config hostConfig := fmt.Sprintf(` Host %s User %s @@ -337,24 +336,12 @@ Host %s ProxyCommand %s `, hostName, userName, keyPath, proxyCommand) - // Check if the host config already exists - exists, err := sshconfig.HostConfigExists(hostName) + _, err = sshconfig.CreateOrUpdateHostConfig(ctx, hostName, hostConfig, true) if err != nil { return err } - if exists { - cmdio.LogString(ctx, fmt.Sprintf("SSH config entry for '%s' already exists, skipping", hostName)) - return nil - } - - // Create the host config file - _, err = sshconfig.CreateOrUpdateHostConfig(ctx, hostName, hostConfig, false) - if err != nil { - return err - } - - cmdio.LogString(ctx, fmt.Sprintf("Added SSH config entry for '%s'", hostName)) + cmdio.LogString(ctx, fmt.Sprintf("Updated SSH config entry for '%s'", hostName)) return nil } diff --git a/experimental/ssh/internal/sshconfig/sshconfig.go b/experimental/ssh/internal/sshconfig/sshconfig.go index 2b3e886c87..1d328d6041 100644 --- a/experimental/ssh/internal/sshconfig/sshconfig.go +++ b/experimental/ssh/internal/sshconfig/sshconfig.go @@ -151,7 +151,6 @@ func CreateOrUpdateHostConfig(ctx context.Context, hostName, hostConfig string, return false, nil } - // Ensure the config directory exists configDir := filepath.Dir(configPath) err = os.MkdirAll(configDir, 0o700) if err != nil { From 79f9ff54321dedfc3ab59a0b4171bec729994bcb Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Mon, 2 Feb 2026 11:40:32 +0100 Subject: [PATCH 18/25] Fix sshconfig logic and tests on Windows --- experimental/ssh/internal/setup/setup_test.go | 16 +++---- .../ssh/internal/sshconfig/sshconfig.go | 16 +++---- .../ssh/internal/sshconfig/sshconfig_test.go | 45 ++++++++++++------- 3 files changed, 41 insertions(+), 36 deletions(-) diff --git a/experimental/ssh/internal/setup/setup_test.go b/experimental/ssh/internal/setup/setup_test.go index 43e4f105b8..975828a3c8 100644 --- a/experimental/ssh/internal/setup/setup_test.go +++ b/experimental/ssh/internal/setup/setup_test.go @@ -205,9 +205,8 @@ func TestGenerateHostConfig_PathEscaping(t *testing.T) { func TestSetup_SuccessfulWithNewConfigFile(t *testing.T) { ctx := cmdio.MockDiscard(context.Background()) tmpDir := t.TempDir() - homeDir := os.Getenv("HOME") - defer func() { os.Setenv("HOME", homeDir) }() - os.Setenv("HOME", tmpDir) + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) configPath := filepath.Join(tmpDir, "ssh_config") @@ -245,10 +244,11 @@ func TestSetup_SuccessfulWithNewConfigFile(t *testing.T) { assert.NoError(t, err) configStr := string(content) assert.Contains(t, configStr, "Include") + // SSH config uses forward slashes on all platforms assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") // Check that host config file was created - hostConfigPath := filepath.Join(tmpDir, ".databricks/ssh-tunnel-configs/test-host") + hostConfigPath := filepath.Join(tmpDir, ".databricks", "ssh-tunnel-configs", "test-host") hostContent, err := os.ReadFile(hostConfigPath) assert.NoError(t, err) hostConfigStr := string(hostContent) @@ -260,9 +260,8 @@ func TestSetup_SuccessfulWithNewConfigFile(t *testing.T) { func TestSetup_SuccessfulWithExistingConfigFile(t *testing.T) { ctx := cmdio.MockDiscard(context.Background()) tmpDir := t.TempDir() - homeDir := os.Getenv("HOME") - defer func() { os.Setenv("HOME", homeDir) }() - os.Setenv("HOME", tmpDir) + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) configPath := filepath.Join(tmpDir, "ssh_config") @@ -304,12 +303,13 @@ func TestSetup_SuccessfulWithExistingConfigFile(t *testing.T) { assert.NoError(t, err) configStr := string(content) assert.Contains(t, configStr, "Include") + // SSH config uses forward slashes on all platforms assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") assert.Contains(t, configStr, "# Existing SSH Config") assert.Contains(t, configStr, "Host existing-host") // Check that host config file was created - hostConfigPath := filepath.Join(tmpDir, ".databricks/ssh-tunnel-configs/new-host") + hostConfigPath := filepath.Join(tmpDir, ".databricks", "ssh-tunnel-configs", "new-host") hostContent, err := os.ReadFile(hostConfigPath) assert.NoError(t, err) hostConfigStr := string(hostContent) diff --git a/experimental/ssh/internal/sshconfig/sshconfig.go b/experimental/ssh/internal/sshconfig/sshconfig.go index 1d328d6041..968bd987a2 100644 --- a/experimental/ssh/internal/sshconfig/sshconfig.go +++ b/experimental/ssh/internal/sshconfig/sshconfig.go @@ -5,7 +5,6 @@ import ( "fmt" "os" "path/filepath" - "regexp" "strings" "github.com/databricks/cli/libs/cmdio" @@ -82,20 +81,15 @@ func EnsureIncludeDirective(configPath string) error { return fmt.Errorf("failed to read SSH config file: %w", err) } - // Check if Include directive already exists - includePattern := fmt.Sprintf(`(?m)^\s*Include\s+.*%s/\*\s*$`, regexp.QuoteMeta(ConfigDirName)) - matched, err := regexp.Match(includePattern, content) - if err != nil { - return fmt.Errorf("failed to check for existing Include directive: %w", err) - } + // Convert path to forward slashes for SSH config compatibility across platforms + configDirUnix := filepath.ToSlash(configDir) - if matched { + includeLine := fmt.Sprintf("Include %s/*", configDirUnix) + if strings.Contains(string(content), includeLine) { return nil } - // Prepend the Include directive - includeLine := fmt.Sprintf("Include %s/*\n", configDir) - newContent := includeLine + newContent := includeLine + "\n" if len(content) > 0 && !strings.HasPrefix(string(content), "\n") { newContent += "\n" } diff --git a/experimental/ssh/internal/sshconfig/sshconfig_test.go b/experimental/ssh/internal/sshconfig/sshconfig_test.go index 98b7cef886..e6a9f3fc4a 100644 --- a/experimental/ssh/internal/sshconfig/sshconfig_test.go +++ b/experimental/ssh/internal/sshconfig/sshconfig_test.go @@ -14,13 +14,13 @@ import ( func TestGetConfigDir(t *testing.T) { dir, err := GetConfigDir() assert.NoError(t, err) - assert.Contains(t, dir, ".databricks/ssh-tunnel-configs") + assert.Contains(t, dir, filepath.Join(".databricks", "ssh-tunnel-configs")) } func TestGetMainConfigPath(t *testing.T) { path, err := GetMainConfigPath() assert.NoError(t, err) - assert.Contains(t, path, ".ssh/config") + assert.Contains(t, path, filepath.Join(".ssh", "config")) } func TestGetMainConfigPathOrDefault(t *testing.T) { @@ -30,7 +30,7 @@ func TestGetMainConfigPathOrDefault(t *testing.T) { path, err = GetMainConfigPathOrDefault("") assert.NoError(t, err) - assert.Contains(t, path, ".ssh/config") + assert.Contains(t, path, filepath.Join(".ssh", "config")) } func TestEnsureMainConfigExists(t *testing.T) { @@ -55,6 +55,10 @@ func TestEnsureIncludeDirective_NewConfig(t *testing.T) { tmpDir := t.TempDir() configPath := filepath.Join(tmpDir, ".ssh", "config") + // Set home directory for test + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + err := EnsureIncludeDirective(configPath) assert.NoError(t, err) @@ -63,17 +67,23 @@ func TestEnsureIncludeDirective_NewConfig(t *testing.T) { configStr := string(content) assert.Contains(t, configStr, "Include") + // SSH config uses forward slashes on all platforms assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") } func TestEnsureIncludeDirective_AlreadyExists(t *testing.T) { tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + configPath := filepath.Join(tmpDir, ".ssh", "config") configDir, err := GetConfigDir() require.NoError(t, err) - existingContent := "Include " + configDir + "/*\n\nHost example\n User test\n" + // Use forward slashes as that's what SSH config uses + configDirUnix := filepath.ToSlash(configDir) + existingContent := "Include " + configDirUnix + "/*\n\nHost example\n User test\n" err = os.MkdirAll(filepath.Dir(configPath), 0o700) require.NoError(t, err) err = os.WriteFile(configPath, []byte(existingContent), 0o600) @@ -91,6 +101,10 @@ func TestEnsureIncludeDirective_PrependsToExisting(t *testing.T) { tmpDir := t.TempDir() configPath := filepath.Join(tmpDir, ".ssh", "config") + // Set home directory for test + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + existingContent := "Host example\n User test\n" err := os.MkdirAll(filepath.Dir(configPath), 0o700) require.NoError(t, err) @@ -105,6 +119,7 @@ func TestEnsureIncludeDirective_PrependsToExisting(t *testing.T) { configStr := string(content) assert.Contains(t, configStr, "Include") + // SSH config uses forward slashes on all platforms assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") assert.Contains(t, configStr, "Host example") @@ -116,14 +131,13 @@ func TestEnsureIncludeDirective_PrependsToExisting(t *testing.T) { func TestGetHostConfigPath(t *testing.T) { path, err := GetHostConfigPath("test-host") assert.NoError(t, err) - assert.Contains(t, path, ".databricks/ssh-tunnel-configs/test-host") + assert.Contains(t, path, filepath.Join(".databricks", "ssh-tunnel-configs", "test-host")) } func TestHostConfigExists(t *testing.T) { tmpDir := t.TempDir() - homeDir := os.Getenv("HOME") - defer func() { os.Setenv("HOME", homeDir) }() - os.Setenv("HOME", tmpDir) + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) exists, err := HostConfigExists("nonexistent") assert.NoError(t, err) @@ -143,9 +157,8 @@ func TestHostConfigExists(t *testing.T) { func TestCreateOrUpdateHostConfig_NewConfig(t *testing.T) { ctx := cmdio.MockDiscard(context.Background()) tmpDir := t.TempDir() - homeDir := os.Getenv("HOME") - defer func() { os.Setenv("HOME", homeDir) }() - os.Setenv("HOME", tmpDir) + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) hostConfig := "Host test\n User root\n" created, err := CreateOrUpdateHostConfig(ctx, "test-host", hostConfig, false) @@ -162,9 +175,8 @@ func TestCreateOrUpdateHostConfig_NewConfig(t *testing.T) { func TestCreateOrUpdateHostConfig_ExistingConfigNoRecreate(t *testing.T) { ctx := cmdio.MockDiscard(context.Background()) tmpDir := t.TempDir() - homeDir := os.Getenv("HOME") - defer func() { os.Setenv("HOME", homeDir) }() - os.Setenv("HOME", tmpDir) + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) configDir := filepath.Join(tmpDir, ConfigDirName) err := os.MkdirAll(configDir, 0o700) @@ -188,9 +200,8 @@ func TestCreateOrUpdateHostConfig_ExistingConfigNoRecreate(t *testing.T) { func TestCreateOrUpdateHostConfig_ExistingConfigWithRecreate(t *testing.T) { ctx := cmdio.MockDiscard(context.Background()) tmpDir := t.TempDir() - homeDir := os.Getenv("HOME") - defer func() { os.Setenv("HOME", homeDir) }() - os.Setenv("HOME", tmpDir) + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) configDir := filepath.Join(tmpDir, ConfigDirName) err := os.MkdirAll(configDir, 0o700) From 27dc6d4900d0d891793f693d1ab0b3de3a6df327 Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Mon, 2 Feb 2026 15:14:21 +0100 Subject: [PATCH 19/25] Do not create secret scope if it already exists --- experimental/ssh/internal/keys/secrets.go | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/experimental/ssh/internal/keys/secrets.go b/experimental/ssh/internal/keys/secrets.go index eac692f235..d4e00d10ba 100644 --- a/experimental/ssh/internal/keys/secrets.go +++ b/experimental/ssh/internal/keys/secrets.go @@ -18,10 +18,23 @@ func CreateKeysSecretScope(ctx context.Context, client *databricks.WorkspaceClie return "", fmt.Errorf("failed to get current user: %w", err) } secretScopeName := fmt.Sprintf("%s-%s-ssh-tunnel-keys", me.UserName, sessionID) + + // Do not create the scope if it already exists. + // We can instead filter out "resource already exists" errors from CreateScope, + // but that API can also lead to "limit exceeded" errors, even if the scope does actually exist. + scope, err := client.Secrets.ListSecretsByScope(ctx, secretScopeName) + if err != nil && !errors.Is(err, databricks.ErrResourceDoesNotExist) { + return "", fmt.Errorf("failed to check if secret scope %s exists: %w", secretScopeName, err) + } + + if scope != nil && err == nil { + return secretScopeName, nil + } + err = client.Secrets.CreateScope(ctx, workspace.CreateScope{ Scope: secretScopeName, }) - if err != nil && !errors.Is(err, databricks.ErrResourceAlreadyExists) { + if err != nil { return "", fmt.Errorf("failed to create secrets scope: %w", err) } return secretScopeName, nil From a6712299da93d41c825877b15eac9c12d5cd3667 Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Tue, 3 Feb 2026 16:24:31 +0100 Subject: [PATCH 20/25] Better SSHD config escaping --- experimental/ssh/internal/server/sshd.go | 17 +++-- experimental/ssh/internal/server/sshd_test.go | 63 +++++++++++++++++++ 2 files changed, 75 insertions(+), 5 deletions(-) create mode 100644 experimental/ssh/internal/server/sshd_test.go diff --git a/experimental/ssh/internal/server/sshd.go b/experimental/ssh/internal/server/sshd.go index 7a038e73b5..5be40c7a68 100644 --- a/experimental/ssh/internal/server/sshd.go +++ b/experimental/ssh/internal/server/sshd.go @@ -52,15 +52,13 @@ func prepareSSHDConfig(ctx context.Context, client *databricks.WorkspaceClient, return "", err } - // Set all available env vars, wrapping values in quotes and escaping quotes inside values + // Set all available env vars, wrapping values in quotes, escaping quotes, and stripping newlines setEnv := "SetEnv" for _, env := range os.Environ() { parts := strings.SplitN(env, "=", 2) - if len(parts) != 2 { - continue + if len(parts) == 2 { + setEnv += " " + parts[0] + "=\"" + escapeEnvValue(parts[1]) + "\"" } - valEscaped := strings.ReplaceAll(parts[1], "\"", "\\\"") - setEnv += " " + parts[0] + "=\"" + valEscaped + "\"" } setEnv += " DATABRICKS_CLI_UPSTREAM=databricks_ssh_tunnel" setEnv += " DATABRICKS_CLI_UPSTREAM_VERSION=" + opts.Version @@ -94,3 +92,12 @@ func prepareSSHDConfig(ctx context.Context, client *databricks.WorkspaceClient, func createSSHDProcess(ctx context.Context, configPath string) *exec.Cmd { return exec.CommandContext(ctx, "/usr/sbin/sshd", "-f", configPath, "-i") } + +// escapeEnvValue escapes a value for use in sshd SetEnv directive. +// It strips newlines and escapes quotes. +func escapeEnvValue(val string) string { + val = strings.ReplaceAll(val, "\r", "") + val = strings.ReplaceAll(val, "\n", "") + val = strings.ReplaceAll(val, "\"", "\\\"") + return val +} diff --git a/experimental/ssh/internal/server/sshd_test.go b/experimental/ssh/internal/server/sshd_test.go new file mode 100644 index 0000000000..ffcf9225b5 --- /dev/null +++ b/experimental/ssh/internal/server/sshd_test.go @@ -0,0 +1,63 @@ +package server + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEscapeEnvValue(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "simple value", + input: "hello", + expected: "hello", + }, + { + name: "value with quotes", + input: `say "hello"`, + expected: `say \"hello\"`, + }, + { + name: "value with newline", + input: "line1\nline2", + expected: "line1line2", + }, + { + name: "value with carriage return", + input: "line1\rline2", + expected: "line1line2", + }, + { + name: "value with CRLF", + input: "line1\r\nline2", + expected: "line1line2", + }, + { + name: "value with quotes and newlines", + input: "say \"hello\"\nworld", + expected: `say \"hello\"world`, + }, + { + name: "empty value", + input: "", + expected: "", + }, + { + name: "only newlines", + input: "\n\r\n", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := escapeEnvValue(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} From 9ab10c36f5cd996e86d8bbf48913a6a95fbe534d Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Thu, 5 Feb 2026 16:15:44 +0100 Subject: [PATCH 21/25] Avoid duplication of host config generation code --- experimental/ssh/internal/client/client.go | 10 +------ experimental/ssh/internal/setup/setup.go | 11 +------- .../ssh/internal/sshconfig/sshconfig.go | 27 ++++++++++--------- .../ssh/internal/sshconfig/sshconfig_test.go | 6 ++--- 4 files changed, 20 insertions(+), 34 deletions(-) diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index c3edb61e63..940f792f0e 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -326,15 +326,7 @@ func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, k return fmt.Errorf("failed to generate ProxyCommand: %w", err) } - hostConfig := fmt.Sprintf(` -Host %s - User %s - ConnectTimeout 360 - StrictHostKeyChecking accept-new - IdentitiesOnly yes - IdentityFile %q - ProxyCommand %s -`, hostName, userName, keyPath, proxyCommand) + hostConfig := sshconfig.GenerateHostConfig(hostName, userName, keyPath, proxyCommand) _, err = sshconfig.CreateOrUpdateHostConfig(ctx, hostName, hostConfig, true) if err != nil { diff --git a/experimental/ssh/internal/setup/setup.go b/experimental/ssh/internal/setup/setup.go index 3cd5d41182..99b5a68902 100644 --- a/experimental/ssh/internal/setup/setup.go +++ b/experimental/ssh/internal/setup/setup.go @@ -49,16 +49,7 @@ func generateHostConfig(opts SetupOptions) (string, error) { return "", fmt.Errorf("failed to get local keys folder: %w", err) } - hostConfig := fmt.Sprintf(` -Host %s - User root - ConnectTimeout 360 - StrictHostKeyChecking accept-new - IdentitiesOnly yes - IdentityFile %q - ProxyCommand %s -`, opts.HostName, identityFilePath, opts.ProxyCommand) - + hostConfig := sshconfig.GenerateHostConfig(opts.HostName, "root", identityFilePath, opts.ProxyCommand) return hostConfig, nil } diff --git a/experimental/ssh/internal/sshconfig/sshconfig.go b/experimental/ssh/internal/sshconfig/sshconfig.go index 968bd987a2..3a6713acbf 100644 --- a/experimental/ssh/internal/sshconfig/sshconfig.go +++ b/experimental/ssh/internal/sshconfig/sshconfig.go @@ -11,20 +11,18 @@ import ( ) const ( - // ConfigDirName is the directory name for Databricks SSH tunnel configs - ConfigDirName = ".databricks/ssh-tunnel-configs" + // configDirName is the directory name for Databricks SSH tunnel configs, relative to the user's home directory. + configDirName = ".databricks/ssh-tunnel-configs" ) -// GetConfigDir returns the path to the Databricks SSH tunnel configs directory. func GetConfigDir() (string, error) { homeDir, err := os.UserHomeDir() if err != nil { return "", fmt.Errorf("failed to get home directory: %w", err) } - return filepath.Join(homeDir, ConfigDirName), nil + return filepath.Join(homeDir, configDirName), nil } -// GetMainConfigPath returns the path to the main SSH config file. func GetMainConfigPath() (string, error) { homeDir, err := os.UserHomeDir() if err != nil { @@ -33,7 +31,6 @@ func GetMainConfigPath() (string, error) { return filepath.Join(homeDir, ".ssh", "config"), nil } -// GetMainConfigPathOrDefault returns the provided path if non-empty, otherwise returns the default. func GetMainConfigPathOrDefault(configPath string) (string, error) { if configPath != "" { return configPath, nil @@ -41,7 +38,6 @@ func GetMainConfigPathOrDefault(configPath string) (string, error) { return GetMainConfigPath() } -// EnsureMainConfigExists ensures the main SSH config file exists. func EnsureMainConfigExists(configPath string) error { _, err := os.Stat(configPath) if os.IsNotExist(err) { @@ -59,7 +55,6 @@ func EnsureMainConfigExists(configPath string) error { return err } -// EnsureIncludeDirective ensures the Include directive for Databricks configs exists in the main SSH config. func EnsureIncludeDirective(configPath string) error { configDir, err := GetConfigDir() if err != nil { @@ -103,7 +98,6 @@ func EnsureIncludeDirective(configPath string) error { return nil } -// GetHostConfigPath returns the path to a specific host's config file. func GetHostConfigPath(hostName string) (string, error) { configDir, err := GetConfigDir() if err != nil { @@ -112,7 +106,6 @@ func GetHostConfigPath(hostName string) (string, error) { return filepath.Join(configDir, hostName), nil } -// HostConfigExists checks if a host config file already exists. func HostConfigExists(hostName string) (bool, error) { configPath, err := GetHostConfigPath(hostName) if err != nil { @@ -128,7 +121,6 @@ func HostConfigExists(hostName string) (bool, error) { return true, nil } -// CreateOrUpdateHostConfig creates or updates a host config file. // Returns true if the config was created/updated, false if it was skipped. func CreateOrUpdateHostConfig(ctx context.Context, hostName, hostConfig string, recreate bool) (bool, error) { configPath, err := GetHostConfigPath(hostName) @@ -159,7 +151,6 @@ func CreateOrUpdateHostConfig(ctx context.Context, hostName, hostConfig string, return true, nil } -// PromptRecreateConfig asks the user if they want to recreate an existing config. func PromptRecreateConfig(ctx context.Context, hostName string) (bool, error) { response, err := cmdio.AskYesOrNo(ctx, fmt.Sprintf("Host '%s' already exists. Do you want to recreate the config?", hostName)) if err != nil { @@ -167,3 +158,15 @@ func PromptRecreateConfig(ctx context.Context, hostName string) (bool, error) { } return response, nil } + +func GenerateHostConfig(hostName, userName, identityFile, proxyCommand string) string { + return fmt.Sprintf(` +Host %s + User %s + ConnectTimeout 360 + StrictHostKeyChecking accept-new + IdentitiesOnly yes + IdentityFile %q + ProxyCommand %s +`, hostName, userName, identityFile, proxyCommand) +} diff --git a/experimental/ssh/internal/sshconfig/sshconfig_test.go b/experimental/ssh/internal/sshconfig/sshconfig_test.go index e6a9f3fc4a..5fa13923ee 100644 --- a/experimental/ssh/internal/sshconfig/sshconfig_test.go +++ b/experimental/ssh/internal/sshconfig/sshconfig_test.go @@ -143,7 +143,7 @@ func TestHostConfigExists(t *testing.T) { assert.NoError(t, err) assert.False(t, exists) - configDir := filepath.Join(tmpDir, ConfigDirName) + configDir := filepath.Join(tmpDir, configDirName) err = os.MkdirAll(configDir, 0o700) require.NoError(t, err) err = os.WriteFile(filepath.Join(configDir, "existing-host"), []byte("config"), 0o600) @@ -178,7 +178,7 @@ func TestCreateOrUpdateHostConfig_ExistingConfigNoRecreate(t *testing.T) { t.Setenv("HOME", tmpDir) t.Setenv("USERPROFILE", tmpDir) - configDir := filepath.Join(tmpDir, ConfigDirName) + configDir := filepath.Join(tmpDir, configDirName) err := os.MkdirAll(configDir, 0o700) require.NoError(t, err) existingConfig := "Host test\n User admin\n" @@ -203,7 +203,7 @@ func TestCreateOrUpdateHostConfig_ExistingConfigWithRecreate(t *testing.T) { t.Setenv("HOME", tmpDir) t.Setenv("USERPROFILE", tmpDir) - configDir := filepath.Join(tmpDir, ConfigDirName) + configDir := filepath.Join(tmpDir, configDirName) err := os.MkdirAll(configDir, 0o700) require.NoError(t, err) existingConfig := "Host test\n User admin\n" From 1595782453d0de182df54b95212b5b90a6063dfd Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Fri, 6 Feb 2026 14:53:10 +0100 Subject: [PATCH 22/25] Escape backslashes in SetEnv values --- experimental/ssh/internal/server/sshd.go | 3 ++- experimental/ssh/internal/server/sshd_test.go | 10 ++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/experimental/ssh/internal/server/sshd.go b/experimental/ssh/internal/server/sshd.go index 5be40c7a68..c8f23d02a5 100644 --- a/experimental/ssh/internal/server/sshd.go +++ b/experimental/ssh/internal/server/sshd.go @@ -94,10 +94,11 @@ func createSSHDProcess(ctx context.Context, configPath string) *exec.Cmd { } // escapeEnvValue escapes a value for use in sshd SetEnv directive. -// It strips newlines and escapes quotes. +// It strips newlines and escapes backslashes and quotes. func escapeEnvValue(val string) string { val = strings.ReplaceAll(val, "\r", "") val = strings.ReplaceAll(val, "\n", "") + val = strings.ReplaceAll(val, "\\", "\\\\") val = strings.ReplaceAll(val, "\"", "\\\"") return val } diff --git a/experimental/ssh/internal/server/sshd_test.go b/experimental/ssh/internal/server/sshd_test.go index ffcf9225b5..a453d987a0 100644 --- a/experimental/ssh/internal/server/sshd_test.go +++ b/experimental/ssh/internal/server/sshd_test.go @@ -52,6 +52,16 @@ func TestEscapeEnvValue(t *testing.T) { input: "\n\r\n", expected: "", }, + { + name: "backslashes", + input: `foo\bar\`, + expected: `foo\\bar\\`, + }, + { + name: "backslash before quote", + input: `foo\"bar`, + expected: `foo\\\"bar`, + }, } for _, tt := range tests { From cbf2cdbc01f7c5506b244cc9a1af57b225dec014 Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Thu, 5 Feb 2026 12:23:59 +0100 Subject: [PATCH 23/25] Initialize Spark Connect session in Jupyter init script --- experimental/ssh/internal/server/jupyter-init.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/experimental/ssh/internal/server/jupyter-init.py b/experimental/ssh/internal/server/jupyter-init.py index 3e58b2f94a..3e1ac6937f 100644 --- a/experimental/ssh/internal/server/jupyter-init.py +++ b/experimental/ssh/internal/server/jupyter-init.py @@ -185,7 +185,16 @@ def df_html(df: DataFrame) -> str: html_formatter.for_type(SparkConnectDataframe, df_html) html_formatter.for_type(DataFrame, df_html) +@_log_exceptions +def _initialize_spark_connect_session(): + import os + from dbruntime.spark_connection import get_and_configure_uds_spark + os.environ["SPARK_REMOTE"] = "unix:///databricks/sparkconnect/grpc.sock" + spark = get_and_configure_uds_spark() + globals()["spark"] = spark + _register_magics() _register_formatters() _register_runtime_hooks() +_initialize_spark_connect_session() From d606086bdcd6ed237aadb04cdc199c81e42e267f Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Thu, 5 Feb 2026 16:02:01 +0100 Subject: [PATCH 24/25] Format jupyter-init.py --- experimental/ssh/internal/server/jupyter-init.py | 1 + 1 file changed, 1 insertion(+) diff --git a/experimental/ssh/internal/server/jupyter-init.py b/experimental/ssh/internal/server/jupyter-init.py index 3e1ac6937f..ee21e22e98 100644 --- a/experimental/ssh/internal/server/jupyter-init.py +++ b/experimental/ssh/internal/server/jupyter-init.py @@ -185,6 +185,7 @@ def df_html(df: DataFrame) -> str: html_formatter.for_type(SparkConnectDataframe, df_html) html_formatter.for_type(DataFrame, df_html) + @_log_exceptions def _initialize_spark_connect_session(): import os From 4dc0faa671aac6dfe52106769a19c4b4ffa520c9 Mon Sep 17 00:00:00 2001 From: Ilia Babanov Date: Tue, 10 Feb 2026 12:28:28 +0100 Subject: [PATCH 25/25] Use DatabricksSession for Spark Connect session initialization --- .../ssh/internal/server/jupyter-init.py | 34 +++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/experimental/ssh/internal/server/jupyter-init.py b/experimental/ssh/internal/server/jupyter-init.py index ee21e22e98..0cd463a581 100644 --- a/experimental/ssh/internal/server/jupyter-init.py +++ b/experimental/ssh/internal/server/jupyter-init.py @@ -187,7 +187,7 @@ def df_html(df: DataFrame) -> str: @_log_exceptions -def _initialize_spark_connect_session(): +def _initialize_spark_connect_session_grpc(): import os from dbruntime.spark_connection import get_and_configure_uds_spark os.environ["SPARK_REMOTE"] = "unix:///databricks/sparkconnect/grpc.sock" @@ -195,7 +195,37 @@ def _initialize_spark_connect_session(): globals()["spark"] = spark +@_log_exceptions +def _initialize_spark_connect_session_dbconnect(): + import IPython + from databricks.connect import DatabricksSession + user_ns = getattr(IPython.get_ipython(), "user_ns", {}) + existing_session = getattr(user_ns, "spark", None) + if existing_session is not None and _is_spark_connect(existing_session): + return + try: + # Clear the existing local spark session, otherwise DatabricksSession will re-use it. + user_ns["spark"] = None + globals()["spark"] = None + # DatabricksSession will use the existing env vars for the connection. + spark_session = DatabricksSession.builder.getOrCreate() + user_ns["spark"] = spark_session + globals()["spark"] = spark_session + except Exception as e: + user_ns["spark"] = existing_session + globals()["spark"] = existing_session + raise e + + +def _is_spark_connect(session) -> bool: + try: + from pyspark.sql.connect.session import SparkSession as ConnectSparkSession + return isinstance(session, ConnectSparkSession) + except ImportError: + return False + + _register_magics() _register_formatters() _register_runtime_hooks() -_initialize_spark_connect_session() +_initialize_spark_connect_session_dbconnect()