diff --git a/experimental/ssh/cmd/connect.go b/experimental/ssh/cmd/connect.go index 1dc9c22337..4eca1aee7b 100644 --- a/experimental/ssh/cmd/connect.go +++ b/experimental/ssh/cmd/connect.go @@ -1,6 +1,7 @@ package ssh import ( + "errors" "time" "github.com/databricks/cli/cmd/root" @@ -22,7 +23,10 @@ the SSH server and handling the connection proxy. } var clusterID string + var connectionName string + var accelerator string var proxyMode bool + var ide string var serverMetadata string var shutdownDelay time.Duration var maxClients int @@ -30,13 +34,20 @@ the SSH server and handling the connection proxy. var releasesDir string var autoStartCluster bool var userKnownHostsFile string + var liteswap string - cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (required)") - cmd.MarkFlagRequired("cluster") + cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)") cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects") cmd.Flags().IntVar(&maxClients, "max-clients", defaultMaxClients, "Maximum number of SSH clients") cmd.Flags().BoolVar(&autoStartCluster, "auto-start-cluster", true, "Automatically start the cluster if it is not running") + cmd.Flags().StringVar(&connectionName, "name", "", "Connection name (for serverless compute)") + cmd.Flags().MarkHidden("name") + cmd.Flags().StringVar(&accelerator, "accelerator", "", "GPU accelerator type (GPU_1xA10 or GPU_8xH100)") + cmd.Flags().MarkHidden("accelerator") + cmd.Flags().StringVar(&ide, "ide", "", "Open remote IDE window (vscode or cursor)") + cmd.Flags().MarkHidden("ide") + cmd.Flags().BoolVar(&proxyMode, "proxy", false, "ProxyCommand mode") cmd.Flags().MarkHidden("proxy") cmd.Flags().StringVar(&serverMetadata, "metadata", "", "Metadata of the running SSH server (format: ,)") @@ -50,6 +61,9 @@ the SSH server and handling the connection proxy. cmd.Flags().StringVar(&userKnownHostsFile, "user-known-hosts-file", "", "Path to user known hosts file for SSH client") cmd.Flags().MarkHidden("user-known-hosts-file") + cmd.Flags().StringVar(&liteswap, "liteswap", "", "Liteswap header value for traffic routing (dev/test only)") + cmd.Flags().MarkHidden("liteswap") + cmd.PreRunE = func(cmd *cobra.Command, args []string) error { // CLI in the proxy mode is executed by the ssh client and can't prompt for input if proxyMode { @@ -64,20 +78,41 @@ the SSH server and handling the connection proxy. cmd.RunE = func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() wsClient := cmdctx.WorkspaceClient(ctx) + + if !proxyMode && clusterID == "" && connectionName == "" { + return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)") + } + + if accelerator != "" && connectionName == "" { + return errors.New("--accelerator flag can only be used with serverless compute (--name flag)") + } + + // Remove when we add support for serverless CPU + if connectionName != "" && accelerator == "" { + return errors.New("--name flag requires --accelerator to be set (for now we only support serverless GPU compute)") + } + + // TODO: validate connectionName if provided + opts := client.ClientOptions{ Profile: wsClient.Config.Profile, ClusterID: clusterID, + ConnectionName: connectionName, + Accelerator: accelerator, ProxyMode: proxyMode, + IDE: ide, ServerMetadata: serverMetadata, ShutdownDelay: shutdownDelay, MaxClients: maxClients, HandoverTimeout: handoverTimeout, ReleasesDir: releasesDir, ServerTimeout: max(serverTimeout, shutdownDelay), + TaskStartupTimeout: taskStartupTimeout, AutoStartCluster: autoStartCluster, ClientPublicKeyName: clientPublicKeyName, ClientPrivateKeyName: clientPrivateKeyName, UserKnownHostsFile: userKnownHostsFile, + Liteswap: liteswap, AdditionalArgs: args, } return client.Run(ctx, wsClient, opts) diff --git a/experimental/ssh/cmd/constants.go b/experimental/ssh/cmd/constants.go index e812726845..db9f5ccc70 100644 --- a/experimental/ssh/cmd/constants.go +++ b/experimental/ssh/cmd/constants.go @@ -9,6 +9,7 @@ const ( defaultHandoverTimeout = 30 * time.Minute serverTimeout = 24 * time.Hour + taskStartupTimeout = 10 * time.Minute serverPortRange = 100 serverConfigDir = ".ssh-tunnel" serverPrivateKeyName = "server-private-key" diff --git a/experimental/ssh/cmd/server.go b/experimental/ssh/cmd/server.go index 77b8c1c156..efe283f28a 100644 --- a/experimental/ssh/cmd/server.go +++ b/experimental/ssh/cmd/server.go @@ -26,12 +26,15 @@ and proxies them to local SSH daemon processes. var maxClients int var shutdownDelay time.Duration var clusterID string + var sessionID string var version string var secretScopeName string var authorizedKeySecretName string cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID") cmd.MarkFlagRequired("cluster") + cmd.Flags().StringVar(&sessionID, "session-id", "", "Session identifier (cluster ID or serverless connection name)") + cmd.MarkFlagRequired("session-id") cmd.Flags().StringVar(&secretScopeName, "secret-scope-name", "", "Databricks secret scope name to store SSH keys") cmd.MarkFlagRequired("secret-scope-name") cmd.Flags().StringVar(&authorizedKeySecretName, "authorized-key-secret-name", "", "Name of the secret containing the client public key") @@ -56,6 +59,7 @@ and proxies them to local SSH daemon processes. wsc := cmdctx.WorkspaceClient(ctx) opts := server.ServerOptions{ ClusterID: clusterID, + SessionID: sessionID, MaxClients: maxClients, ShutdownDelay: shutdownDelay, Version: version, diff --git a/experimental/ssh/cmd/setup.go b/experimental/ssh/cmd/setup.go index 3e4523904c..81b7863666 100644 --- a/experimental/ssh/cmd/setup.go +++ b/experimental/ssh/cmd/setup.go @@ -1,9 +1,11 @@ package ssh import ( + "fmt" "time" "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/experimental/ssh/internal/client" "github.com/databricks/cli/experimental/ssh/internal/setup" "github.com/databricks/cli/libs/cmdctx" "github.com/spf13/cobra" @@ -43,16 +45,27 @@ an SSH host configuration to your SSH config file. cmd.RunE = func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - client := cmdctx.WorkspaceClient(ctx) - opts := setup.SetupOptions{ + wsClient := cmdctx.WorkspaceClient(ctx) + setupOpts := setup.SetupOptions{ HostName: hostName, ClusterID: clusterID, AutoStartCluster: autoStartCluster, SSHConfigPath: sshConfigPath, ShutdownDelay: shutdownDelay, - Profile: client.Config.Profile, + Profile: wsClient.Config.Profile, } - return setup.Setup(ctx, client, opts) + clientOpts := client.ClientOptions{ + ClusterID: setupOpts.ClusterID, + AutoStartCluster: setupOpts.AutoStartCluster, + ShutdownDelay: setupOpts.ShutdownDelay, + Profile: setupOpts.Profile, + } + proxyCommand, err := clientOpts.ToProxyCommand() + if err != nil { + return fmt.Errorf("failed to generate ProxyCommand: %w", err) + } + setupOpts.ProxyCommand = proxyCommand + return setup.Setup(ctx, wsClient, setupOpts) } return cmd diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index a3c3d78889..940f792f0e 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -19,11 +19,13 @@ import ( "github.com/databricks/cli/experimental/ssh/internal/keys" "github.com/databricks/cli/experimental/ssh/internal/proxy" - "github.com/databricks/cli/experimental/ssh/internal/setup" + "github.com/databricks/cli/experimental/ssh/internal/sshconfig" sshWorkspace "github.com/databricks/cli/experimental/ssh/internal/workspace" "github.com/databricks/cli/internal/build" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/log" "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/retries" "github.com/databricks/databricks-sdk-go/service/compute" "github.com/databricks/databricks-sdk-go/service/jobs" "github.com/databricks/databricks-sdk-go/service/workspace" @@ -35,9 +37,23 @@ var sshServerBootstrapScript string var errServerMetadata = errors.New("server metadata error") +const ( + sshServerTaskKey = "start_ssh_server" + serverlessEnvironmentKey = "ssh_tunnel_serverless" + + VSCodeOption = "vscode" + VSCodeCommand = "code" + CursorOption = "cursor" + CursorCommand = "cursor" +) + type ClientOptions struct { - // Id of the cluster to connect to + // Id of the cluster to connect to (for dedicated clusters) ClusterID string + // Connection name (for serverless compute). Used as unique identifier instead of ClusterID. + ConnectionName string + // GPU accelerator type (for serverless compute) + Accelerator string // Delay before shutting down the server after the last client disconnects ShutdownDelay time.Duration // Maximum number of SSH clients @@ -46,13 +62,17 @@ type ClientOptions struct { // to the cluster and proxy all traffic through stdin/stdout. // In the non proxy mode the CLI spawns an ssh client with the ProxyCommand config. ProxyMode bool - // Expected format: ",". + // Open remote IDE window with a specific ssh config (empty, 'vscode', or 'cursor') + IDE string + // Expected format: ",,". // If present, the CLI won't attempt to start the server. ServerMetadata string // How often the CLI should reconnect to the server with new auth. HandoverTimeout time.Duration // Max amount of time the server process is allowed to live ServerTimeout time.Duration + // Max amount of time to wait for the SSH server task to reach RUNNING state + TaskStartupTimeout time.Duration // Directory for local SSH tunnel development releases. // If not present, the CLI will use github releases with the current version. ReleasesDir string @@ -70,6 +90,73 @@ type ClientOptions struct { AdditionalArgs []string // Optional path to the user known hosts file. UserKnownHostsFile string + // Liteswap header value for traffic routing (dev/test only). + Liteswap string +} + +func (o *ClientOptions) IsServerlessMode() bool { + return o.ClusterID == "" && o.ConnectionName != "" +} + +// SessionIdentifier returns the unique identifier for the session. +// For dedicated clusters, this is the cluster ID. For serverless, this is the connection name. +func (o *ClientOptions) SessionIdentifier() string { + if o.IsServerlessMode() { + return o.ConnectionName + } + return o.ClusterID +} + +// FormatMetadata formats the server metadata string for use in ProxyCommand. +// Returns empty string if userName is empty or serverPort is zero. +func FormatMetadata(userName string, serverPort int, clusterID string) string { + if userName == "" || serverPort == 0 { + return "" + } + if clusterID != "" { + return fmt.Sprintf("%s,%d,%s", userName, serverPort, clusterID) + } + return fmt.Sprintf("%s,%d", userName, serverPort) +} + +// ToProxyCommand generates the ProxyCommand string for SSH config. +// This method serializes the ClientOptions into a command-line invocation that will +// be parsed back into ClientOptions when the SSH ProxyCommand is executed. +func (o *ClientOptions) ToProxyCommand() (string, error) { + executablePath, err := os.Executable() + if err != nil { + return "", fmt.Errorf("failed to get current executable path: %w", err) + } + + var proxyCommand string + if o.IsServerlessMode() { + proxyCommand = fmt.Sprintf("%q ssh connect --proxy --name=%s --shutdown-delay=%s", + executablePath, o.ConnectionName, o.ShutdownDelay.String()) + if o.Accelerator != "" { + proxyCommand += " --accelerator=" + o.Accelerator + } + } else { + proxyCommand = fmt.Sprintf("%q ssh connect --proxy --cluster=%s --auto-start-cluster=%t --shutdown-delay=%s", + executablePath, o.ClusterID, o.AutoStartCluster, o.ShutdownDelay.String()) + } + + if o.ServerMetadata != "" { + proxyCommand += " --metadata=" + o.ServerMetadata + } + + if o.HandoverTimeout > 0 { + proxyCommand += " --handover-timeout=" + o.HandoverTimeout.String() + } + + if o.Profile != "" { + proxyCommand += " --profile=" + o.Profile + } + + if o.Liteswap != "" { + proxyCommand += " --liteswap=" + o.Liteswap + } + + return proxyCommand, nil } func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOptions) error { @@ -84,22 +171,30 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt cancel() }() - err := checkClusterState(ctx, client, opts.ClusterID, opts.AutoStartCluster) - if err != nil { - return err + sessionID := opts.SessionIdentifier() + if sessionID == "" { + return errors.New("either --cluster or --name must be provided") + } + + // Only check cluster state for dedicated clusters + if !opts.IsServerlessMode() { + err := checkClusterState(ctx, client, opts.ClusterID, opts.AutoStartCluster) + if err != nil { + return err + } } - secretScopeName, err := keys.CreateKeysSecretScope(ctx, client, opts.ClusterID) + secretScopeName, err := keys.CreateKeysSecretScope(ctx, client, sessionID) if err != nil { return fmt.Errorf("failed to create secret scope: %w", err) } - privateKeyBytes, publicKeyBytes, err := keys.CheckAndGenerateSSHKeyPairFromSecrets(ctx, client, opts.ClusterID, secretScopeName, opts.ClientPrivateKeyName, opts.ClientPublicKeyName) + privateKeyBytes, publicKeyBytes, err := keys.CheckAndGenerateSSHKeyPairFromSecrets(ctx, client, secretScopeName, opts.ClientPrivateKeyName, opts.ClientPublicKeyName) if err != nil { return fmt.Errorf("failed to get or generate SSH key pair from secrets: %w", err) } - keyPath, err := keys.GetLocalSSHKeyPath(opts.ClusterID, opts.SSHKeysDir) + keyPath, err := keys.GetLocalSSHKeyPath(sessionID, opts.SSHKeysDir) if err != nil { return fmt.Errorf("failed to get local keys folder: %w", err) } @@ -113,6 +208,7 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt var userName string var serverPort int + var clusterID string version := build.GetInfo().Version @@ -121,14 +217,15 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt if err := UploadTunnelReleases(ctx, client, version, opts.ReleasesDir); err != nil { return fmt.Errorf("failed to upload ssh-tunnel binaries: %w", err) } - userName, serverPort, err = ensureSSHServerIsRunning(ctx, client, version, secretScopeName, opts) + userName, serverPort, clusterID, err = ensureSSHServerIsRunning(ctx, client, version, secretScopeName, opts) if err != nil { return fmt.Errorf("failed to ensure that ssh server is running: %w", err) } } else { + // Metadata format: ",," metadata := strings.Split(opts.ServerMetadata, ",") - if len(metadata) != 2 { - return fmt.Errorf("invalid metadata: %s, expected format: ,", opts.ServerMetadata) + if len(metadata) < 2 { + return fmt.Errorf("invalid metadata: %s, expected format: ,[,]", opts.ServerMetadata) } userName = metadata[0] if userName == "" { @@ -138,65 +235,178 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt if err != nil { return fmt.Errorf("cannot parse port from metadata: %s, %w", opts.ServerMetadata, err) } + if len(metadata) >= 3 { + clusterID = metadata[2] + } else { + clusterID = opts.ClusterID + } + } + + // For serverless mode, we need the cluster ID from metadata for Driver Proxy connections + if opts.IsServerlessMode() && clusterID == "" { + return errors.New("cluster ID is required for serverless connections but was not found in metadata") } cmdio.LogString(ctx, "Remote user name: "+userName) cmdio.LogString(ctx, fmt.Sprintf("Server port: %d", serverPort)) + if opts.IsServerlessMode() { + cmdio.LogString(ctx, "Cluster ID (from serverless job): "+clusterID) + } if opts.ProxyMode { - return runSSHProxy(ctx, client, serverPort, opts) + return runSSHProxy(ctx, client, serverPort, clusterID, opts) + } else if opts.IDE != "" { + return runIDE(ctx, client, userName, keyPath, serverPort, clusterID, opts) } else { cmdio.LogString(ctx, fmt.Sprintf("Additional SSH arguments: %v", opts.AdditionalArgs)) - return spawnSSHClient(ctx, userName, keyPath, serverPort, opts) + return spawnSSHClient(ctx, userName, keyPath, serverPort, clusterID, opts) + } +} + +func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error { + if opts.IDE != VSCodeOption && opts.IDE != CursorOption { + return fmt.Errorf("invalid IDE value: %s, expected '%s' or '%s'", opts.IDE, VSCodeOption, CursorOption) + } + + connectionName := opts.SessionIdentifier() + if connectionName == "" { + return errors.New("connection name is required for IDE integration") + } + + // Get Databricks user name for the workspace path + currentUser, err := client.CurrentUser.Me(ctx) + if err != nil { + return fmt.Errorf("failed to get current user: %w", err) + } + databricksUserName := currentUser.UserName + + // Ensure SSH config entry exists + configPath, err := sshconfig.GetMainConfigPath() + if err != nil { + return fmt.Errorf("failed to get SSH config path: %w", err) + } + + err = ensureSSHConfigEntry(ctx, configPath, connectionName, userName, keyPath, serverPort, clusterID, opts) + if err != nil { + return fmt.Errorf("failed to ensure SSH config entry: %w", err) } + + ideCommand := VSCodeCommand + if opts.IDE == CursorOption { + ideCommand = CursorCommand + } + + // Construct the remote SSH URI + // Format: ssh-remote+@ /Workspace/Users// + remoteURI := fmt.Sprintf("ssh-remote+%s@%s", userName, connectionName) + remotePath := fmt.Sprintf("/Workspace/Users/%s/", databricksUserName) + + cmdio.LogString(ctx, fmt.Sprintf("Launching %s with remote URI: %s and path: %s", opts.IDE, remoteURI, remotePath)) + + ideCmd := exec.CommandContext(ctx, ideCommand, "--remote", remoteURI, remotePath) + ideCmd.Stdout = os.Stdout + ideCmd.Stderr = os.Stderr + + return ideCmd.Run() } -func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, clusterID, version string) (int, string, error) { - serverPort, err := sshWorkspace.GetWorkspaceMetadata(ctx, client, version, clusterID) +func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error { + // Ensure the Include directive exists in the main SSH config + err := sshconfig.EnsureIncludeDirective(configPath) if err != nil { - return 0, "", errors.Join(errServerMetadata, err) + return err } + + // Generate ProxyCommand with server metadata + optsWithMetadata := opts + optsWithMetadata.ServerMetadata = FormatMetadata(userName, serverPort, clusterID) + + proxyCommand, err := optsWithMetadata.ToProxyCommand() + if err != nil { + return fmt.Errorf("failed to generate ProxyCommand: %w", err) + } + + hostConfig := sshconfig.GenerateHostConfig(hostName, userName, keyPath, proxyCommand) + + _, err = sshconfig.CreateOrUpdateHostConfig(ctx, hostName, hostConfig, true) + if err != nil { + return err + } + + cmdio.LogString(ctx, fmt.Sprintf("Updated SSH config entry for '%s'", hostName)) + return nil +} + +// getServerMetadata retrieves the server metadata from the workspace and validates it via Driver Proxy. +// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). +// For dedicated clusters, clusterID should be the same as sessionID. +// For serverless, clusterID is read from the workspace metadata. +func getServerMetadata(ctx context.Context, client *databricks.WorkspaceClient, sessionID, clusterID, version, liteswap string) (int, string, string, error) { + wsMetadata, err := sshWorkspace.GetWorkspaceMetadata(ctx, client, version, sessionID) + if err != nil { + return 0, "", "", errors.Join(errServerMetadata, err) + } + log.Debugf(ctx, "Workspace metadata: %+v", wsMetadata) + + // For serverless mode, the cluster ID comes from the metadata + effectiveClusterID := clusterID + if wsMetadata.ClusterID != "" { + effectiveClusterID = wsMetadata.ClusterID + } + + if effectiveClusterID == "" { + return 0, "", "", errors.Join(errServerMetadata, errors.New("cluster ID not available in metadata")) + } + workspaceID, err := client.CurrentWorkspaceID(ctx) if err != nil { - return 0, "", err + return 0, "", "", err } - metadataURL := fmt.Sprintf("%s/driver-proxy-api/o/%d/%s/%d/metadata", client.Config.Host, workspaceID, clusterID, serverPort) + metadataURL := fmt.Sprintf("%s/driver-proxy-api/o/%d/%s/%d/metadata", client.Config.Host, workspaceID, effectiveClusterID, wsMetadata.Port) + log.Debugf(ctx, "Metadata URL: %s", metadataURL) req, err := http.NewRequestWithContext(ctx, "GET", metadataURL, nil) if err != nil { - return 0, "", err + return 0, "", "", err + } + if liteswap != "" { + req.Header.Set("x-databricks-traffic-id", "testenv://liteswap/"+liteswap) } if err := client.Config.Authenticate(req); err != nil { - return 0, "", err + return 0, "", "", err } resp, err := http.DefaultClient.Do(req) if err != nil { - return 0, "", err + return 0, "", "", err } defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return 0, "", errors.Join(errServerMetadata, fmt.Errorf("server is not ok, status code %d", resp.StatusCode)) - } - bodyBytes, err := io.ReadAll(resp.Body) if err != nil { - return 0, "", err + return 0, "", "", err } - return serverPort, string(bodyBytes), nil + log.Debugf(ctx, "Metadata response: %s", string(bodyBytes)) + log.Debugf(ctx, "Metadata response status code: %d", resp.StatusCode) + + if resp.StatusCode != http.StatusOK { + return 0, "", "", errors.Join(errServerMetadata, fmt.Errorf("server is not ok, status code %d", resp.StatusCode)) + } + + return wsMetadata.Port, string(bodyBytes), effectiveClusterID, nil } -func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) (int64, error) { - contentDir, err := sshWorkspace.GetWorkspaceContentDir(ctx, client, version, opts.ClusterID) +func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) error { + sessionID := opts.SessionIdentifier() + contentDir, err := sshWorkspace.GetWorkspaceContentDir(ctx, client, version, sessionID) if err != nil { - return 0, fmt.Errorf("failed to get workspace content directory: %w", err) + return fmt.Errorf("failed to get workspace content directory: %w", err) } err = client.Workspace.MkdirsByPath(ctx, contentDir) if err != nil { - return 0, fmt.Errorf("failed to create directory in the remote workspace: %w", err) + return fmt.Errorf("failed to create directory in the remote workspace: %w", err) } - sshTunnelJobName := "ssh-server-bootstrap-" + opts.ClusterID + sshTunnelJobName := "ssh-server-bootstrap-" + sessionID jobNotebookPath := filepath.ToSlash(filepath.Join(contentDir, "ssh-server-bootstrap")) notebookContent := "# Databricks notebook source\n" + sshServerBootstrapScript encodedContent := base64.StdEncoding.EncodeToString([]byte(notebookContent)) @@ -209,46 +419,80 @@ func submitSSHTunnelJob(ctx context.Context, client *databricks.WorkspaceClient, Overwrite: true, }) if err != nil { - return 0, fmt.Errorf("failed to create ssh-tunnel notebook: %w", err) + return fmt.Errorf("failed to create ssh-tunnel notebook: %w", err) + } + + baseParams := map[string]string{ + "version": version, + "secretScopeName": secretScopeName, + "authorizedKeySecretName": opts.ClientPublicKeyName, + "shutdownDelay": opts.ShutdownDelay.String(), + "maxClients": strconv.Itoa(opts.MaxClients), + "sessionId": sessionID, + } + + cmdio.LogString(ctx, "Submitting a job to start the ssh server...") + + task := jobs.SubmitTask{ + TaskKey: sshServerTaskKey, + NotebookTask: &jobs.NotebookTask{ + NotebookPath: jobNotebookPath, + BaseParameters: baseParams, + }, + TimeoutSeconds: int(opts.ServerTimeout.Seconds()), + } + + if opts.IsServerlessMode() { + task.EnvironmentKey = serverlessEnvironmentKey + if opts.Accelerator != "" { + cmdio.LogString(ctx, "Using accelerator: "+opts.Accelerator) + task.Compute = &jobs.Compute{ + HardwareAccelerator: compute.HardwareAcceleratorType(opts.Accelerator), + } + } + } else { + task.ExistingClusterId = opts.ClusterID } - submitRun := jobs.SubmitRun{ + submitRequest := jobs.SubmitRun{ RunName: sshTunnelJobName, TimeoutSeconds: int(opts.ServerTimeout.Seconds()), - Tasks: []jobs.SubmitTask{ + Tasks: []jobs.SubmitTask{task}, + } + + if opts.IsServerlessMode() { + submitRequest.Environments = []jobs.JobEnvironment{ { - TaskKey: "start_ssh_server", - NotebookTask: &jobs.NotebookTask{ - NotebookPath: jobNotebookPath, - BaseParameters: map[string]string{ - "version": version, - "secretScopeName": secretScopeName, - "authorizedKeySecretName": opts.ClientPublicKeyName, - "shutdownDelay": opts.ShutdownDelay.String(), - "maxClients": strconv.Itoa(opts.MaxClients), - }, + EnvironmentKey: serverlessEnvironmentKey, + Spec: &compute.Environment{ + EnvironmentVersion: "3", }, - TimeoutSeconds: int(opts.ServerTimeout.Seconds()), - ExistingClusterId: opts.ClusterID, }, - }, + } } - cmdio.LogString(ctx, "Submitting a job to start the ssh server...") - runResult, err := client.Jobs.Submit(ctx, submitRun) + waiter, err := client.Jobs.Submit(ctx, submitRequest) if err != nil { - return 0, fmt.Errorf("failed to submit job: %w", err) + return fmt.Errorf("failed to submit job: %w", err) } - return runResult.Response.RunId, nil + cmdio.LogString(ctx, fmt.Sprintf("Job submitted successfully with run ID: %d", waiter.RunId)) + + return waitForJobToStart(ctx, client, waiter.RunId, opts.TaskStartupTimeout) } -func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, serverPort int, opts ClientOptions) error { - proxyCommand, err := setup.GenerateProxyCommand(opts.ClusterID, opts.AutoStartCluster, opts.ShutdownDelay, opts.Profile, userName, serverPort, opts.HandoverTimeout) +func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, serverPort int, clusterID string, opts ClientOptions) error { + // Create a copy with metadata for the ProxyCommand + optsWithMetadata := opts + optsWithMetadata.ServerMetadata = FormatMetadata(userName, serverPort, clusterID) + + proxyCommand, err := optsWithMetadata.ToProxyCommand() if err != nil { return fmt.Errorf("failed to generate ProxyCommand: %w", err) } + hostName := opts.SessionIdentifier() + sshArgs := []string{ "-l", userName, "-i", privateKeyPath, @@ -260,11 +504,10 @@ func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, server if opts.UserKnownHostsFile != "" { sshArgs = append(sshArgs, "-o", "UserKnownHostsFile="+opts.UserKnownHostsFile) } - sshArgs = append(sshArgs, opts.ClusterID) + sshArgs = append(sshArgs, hostName) sshArgs = append(sshArgs, opts.AdditionalArgs...) - cmdio.LogString(ctx, "Launching SSH client: ssh "+strings.Join(sshArgs, " ")) - + log.Debugf(ctx, "Launching SSH client: ssh %s", strings.Join(sshArgs, " ")) sshCmd := exec.CommandContext(ctx, "ssh", sshArgs...) sshCmd.Stdin = os.Stdin @@ -274,9 +517,9 @@ func spawnSSHClient(ctx context.Context, userName, privateKeyPath string, server return sshCmd.Run() } -func runSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, serverPort int, opts ClientOptions) error { +func runSSHProxy(ctx context.Context, client *databricks.WorkspaceClient, serverPort int, clusterID string, opts ClientOptions) error { createConn := func(ctx context.Context, connID string) (*websocket.Conn, error) { - return createWebsocketConnection(ctx, client, connID, opts.ClusterID, serverPort) + return createWebsocketConnection(ctx, client, connID, clusterID, serverPort, opts.Liteswap) } requestHandoverTick := func() <-chan time.Time { return time.After(opts.HandoverTimeout) @@ -304,36 +547,96 @@ func checkClusterState(ctx context.Context, client *databricks.WorkspaceClient, return nil } -func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) (string, int, error) { - serverPort, userName, err := getServerMetadata(ctx, client, opts.ClusterID, version) +// waitForJobToStart polls the task status until the SSH server task is in RUNNING state or terminates. +// Returns an error if the task fails to start or if polling times out. +func waitForJobToStart(ctx context.Context, client *databricks.WorkspaceClient, runID int64, taskStartupTimeout time.Duration) error { + cmdio.LogString(ctx, "Waiting for the SSH server task to start...") + var prevState jobs.RunLifecycleStateV2State + + _, err := retries.Poll(ctx, taskStartupTimeout, func() (*jobs.RunTask, *retries.Err) { + run, err := client.Jobs.GetRun(ctx, jobs.GetRunRequest{ + RunId: runID, + }) + if err != nil { + return nil, retries.Halt(fmt.Errorf("failed to get job run status: %w", err)) + } + + // Find the SSH server task + var sshTask *jobs.RunTask + for i := range run.Tasks { + if run.Tasks[i].TaskKey == sshServerTaskKey { + sshTask = &run.Tasks[i] + break + } + } + + if sshTask == nil { + return nil, retries.Halt(fmt.Errorf("SSH server task '%s' not found in job run", sshServerTaskKey)) + } + + if sshTask.Status == nil { + return nil, retries.Halt(errors.New("task status is nil")) + } + + currentState := sshTask.Status.State + + // Print status if it changed + if currentState != prevState { + cmdio.LogString(ctx, fmt.Sprintf("Task status: %s", currentState)) + prevState = currentState + } + + // Check if task is running + if currentState == jobs.RunLifecycleStateV2StateRunning { + cmdio.LogString(ctx, "SSH server task is now running, proceeding to connect...") + return sshTask, nil + } + + // Check for terminal failure states + if currentState == jobs.RunLifecycleStateV2StateTerminated { + return nil, retries.Halt(errors.New("task terminated before reaching running state")) + } + + // Continue polling for other states + return nil, retries.Continues(fmt.Sprintf("waiting for task to start (current state: %s)", currentState)) + }) + + return err +} + +func ensureSSHServerIsRunning(ctx context.Context, client *databricks.WorkspaceClient, version, secretScopeName string, opts ClientOptions) (string, int, string, error) { + sessionID := opts.SessionIdentifier() + // For dedicated clusters, use clusterID; for serverless, it will be read from metadata + clusterID := opts.ClusterID + + serverPort, userName, effectiveClusterID, err := getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap) if errors.Is(err, errServerMetadata) { cmdio.LogString(ctx, "SSH server is not running, starting it now...") - runID, err := submitSSHTunnelJob(ctx, client, version, secretScopeName, opts) + err := submitSSHTunnelJob(ctx, client, version, secretScopeName, opts) if err != nil { - return "", 0, fmt.Errorf("failed to submit ssh server job: %w", err) + return "", 0, "", fmt.Errorf("failed to submit and start ssh server job: %w", err) } - cmdio.LogString(ctx, fmt.Sprintf("Job submitted successfully with run ID: %d", runID)) cmdio.LogString(ctx, "Waiting for the ssh server to start...") maxRetries := 30 for retries := range maxRetries { if ctx.Err() != nil { - return "", 0, ctx.Err() + return "", 0, "", ctx.Err() } - serverPort, userName, err = getServerMetadata(ctx, client, opts.ClusterID, version) + serverPort, userName, effectiveClusterID, err = getServerMetadata(ctx, client, sessionID, clusterID, version, opts.Liteswap) if err == nil { cmdio.LogString(ctx, "Health check successful, starting ssh WebSocket connection...") break } else if retries < maxRetries-1 { time.Sleep(2 * time.Second) } else { - return "", 0, fmt.Errorf("failed to start the ssh server: %w", err) + return "", 0, "", fmt.Errorf("failed to start the ssh server: %w", err) } } } else if err != nil { - return "", 0, err + return "", 0, "", err } - return userName, serverPort, nil + return userName, serverPort, effectiveClusterID, nil } diff --git a/experimental/ssh/internal/client/ssh-server-bootstrap.py b/experimental/ssh/internal/client/ssh-server-bootstrap.py index 8b8170bf42..8dc0aff16a 100644 --- a/experimental/ssh/internal/client/ssh-server-bootstrap.py +++ b/experimental/ssh/internal/client/ssh-server-bootstrap.py @@ -17,6 +17,7 @@ dbutils.widgets.text("authorizedKeySecretName", "") dbutils.widgets.text("maxClients", "10") dbutils.widgets.text("shutdownDelay", "10m") +dbutils.widgets.text("sessionId", "") # Required: unique identifier for the session def cleanup(): @@ -111,6 +112,9 @@ def run_ssh_server(): shutdown_delay = dbutils.widgets.get("shutdownDelay") max_clients = dbutils.widgets.get("maxClients") + session_id = dbutils.widgets.get("sessionId") + if not session_id: + raise RuntimeError("Session ID is required. Please provide it using the 'sessionId' widget.") arch = platform.machine() if arch == "x86_64": @@ -127,29 +131,29 @@ def run_ssh_server(): binary_path = f"/Workspace/Users/{user_name}/.databricks/ssh-tunnel/{version}/{cli_name}/databricks" + server_args = [ + binary_path, + "ssh", + "server", + f"--cluster={ctx.clusterId}", + f"--session-id={session_id}", + f"--secret-scope-name={secrets_scope}", + f"--authorized-key-secret-name={public_key_secret_name}", + f"--max-clients={max_clients}", + f"--shutdown-delay={shutdown_delay}", + f"--version={version}", + # "info" has enough verbosity for debugging purposes, and "debug" log level prints too much (including secrets) + "--log-level=info", + "--log-format=json", + # To get the server logs: + # 1. Get a job run id from the "databricks ssh connect" output + # 2. Run "databricks jobs get-run " and open a run_page_url + # TODO: file with log rotation + "--log-file=stdout", + ] + try: - subprocess.run( - [ - binary_path, - "ssh", - "server", - f"--cluster={ctx.clusterId}", - f"--secret-scope-name={secrets_scope}", - f"--authorized-key-secret-name={public_key_secret_name}", - f"--max-clients={max_clients}", - f"--shutdown-delay={shutdown_delay}", - f"--version={version}", - # "info" has enough verbosity for debugging purposes, and "debug" log level prints too much (including secrets) - "--log-level=info", - "--log-format=json", - # To get the server logs: - # 1. Get a job run id from the "databricks ssh connect" output - # 2. Run "databricks jobs get-run " and open a run_page_url - # TODO: file with log rotation - "--log-file=stdout", - ], - check=True, - ) + subprocess.run(server_args, check=True) finally: kill_all_children() diff --git a/experimental/ssh/internal/client/websockets.go b/experimental/ssh/internal/client/websockets.go index b1ab20889f..fba53c891e 100644 --- a/experimental/ssh/internal/client/websockets.go +++ b/experimental/ssh/internal/client/websockets.go @@ -9,7 +9,7 @@ import ( "github.com/gorilla/websocket" ) -func createWebsocketConnection(ctx context.Context, client *databricks.WorkspaceClient, connID, clusterID string, serverPort int) (*websocket.Conn, error) { +func createWebsocketConnection(ctx context.Context, client *databricks.WorkspaceClient, connID, clusterID string, serverPort int, liteswap string) (*websocket.Conn, error) { url, err := getProxyURL(ctx, client, connID, clusterID, serverPort) if err != nil { return nil, fmt.Errorf("failed to get proxy URL: %w", err) @@ -20,6 +20,9 @@ func createWebsocketConnection(ctx context.Context, client *databricks.Workspace return nil, fmt.Errorf("failed to create request: %w", err) } + if liteswap != "" { + req.Header.Set("x-databricks-traffic-id", "testenv://liteswap/"+liteswap) + } if err := client.Config.Authenticate(req); err != nil { return nil, fmt.Errorf("failed to authenticate: %w", err) } diff --git a/experimental/ssh/internal/keys/keys.go b/experimental/ssh/internal/keys/keys.go index a1c279c749..735f4f0f1a 100644 --- a/experimental/ssh/internal/keys/keys.go +++ b/experimental/ssh/internal/keys/keys.go @@ -14,8 +14,9 @@ import ( "golang.org/x/crypto/ssh" ) -// We use different client keys for each cluster as a good practice for better isolation and control. -func GetLocalSSHKeyPath(clusterID, keysDir string) (string, error) { +// We use different client keys for each session as a good practice for better isolation and control. +// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). +func GetLocalSSHKeyPath(sessionID, keysDir string) (string, error) { if keysDir == "" { homeDir, err := os.UserHomeDir() if err != nil { @@ -23,7 +24,7 @@ func GetLocalSSHKeyPath(clusterID, keysDir string) (string, error) { } keysDir = filepath.Join(homeDir, ".databricks", "ssh-tunnel-keys") } - return filepath.Join(keysDir, clusterID), nil + return filepath.Join(keysDir, sessionID), nil } func generateSSHKeyPair() ([]byte, []byte, error) { @@ -68,7 +69,7 @@ func SaveSSHKeyPair(keyPath string, privateKeyBytes, publicKeyBytes []byte) erro return nil } -func CheckAndGenerateSSHKeyPairFromSecrets(ctx context.Context, client *databricks.WorkspaceClient, clusterID, secretScopeName, privateKeyName, publicKeyName string) ([]byte, []byte, error) { +func CheckAndGenerateSSHKeyPairFromSecrets(ctx context.Context, client *databricks.WorkspaceClient, secretScopeName, privateKeyName, publicKeyName string) ([]byte, []byte, error) { privateKeyBytes, err := GetSecret(ctx, client, secretScopeName, privateKeyName) if err != nil { privateKeyBytes, publicKeyBytes, err := generateSSHKeyPair() diff --git a/experimental/ssh/internal/keys/secrets.go b/experimental/ssh/internal/keys/secrets.go index 0a7b2c1266..d4e00d10ba 100644 --- a/experimental/ssh/internal/keys/secrets.go +++ b/experimental/ssh/internal/keys/secrets.go @@ -10,16 +10,31 @@ import ( "github.com/databricks/databricks-sdk-go/service/workspace" ) -func CreateKeysSecretScope(ctx context.Context, client *databricks.WorkspaceClient, clusterID string) (string, error) { +// CreateKeysSecretScope creates or retrieves the secret scope for SSH keys. +// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). +func CreateKeysSecretScope(ctx context.Context, client *databricks.WorkspaceClient, sessionID string) (string, error) { me, err := client.CurrentUser.Me(ctx) if err != nil { return "", fmt.Errorf("failed to get current user: %w", err) } - secretScopeName := fmt.Sprintf("%s-%s-ssh-tunnel-keys", me.UserName, clusterID) + secretScopeName := fmt.Sprintf("%s-%s-ssh-tunnel-keys", me.UserName, sessionID) + + // Do not create the scope if it already exists. + // We can instead filter out "resource already exists" errors from CreateScope, + // but that API can also lead to "limit exceeded" errors, even if the scope does actually exist. + scope, err := client.Secrets.ListSecretsByScope(ctx, secretScopeName) + if err != nil && !errors.Is(err, databricks.ErrResourceDoesNotExist) { + return "", fmt.Errorf("failed to check if secret scope %s exists: %w", secretScopeName, err) + } + + if scope != nil && err == nil { + return secretScopeName, nil + } + err = client.Secrets.CreateScope(ctx, workspace.CreateScope{ Scope: secretScopeName, }) - if err != nil && !errors.Is(err, databricks.ErrResourceAlreadyExists) { + if err != nil { return "", fmt.Errorf("failed to create secrets scope: %w", err) } return secretScopeName, nil @@ -53,8 +68,10 @@ func putSecret(ctx context.Context, client *databricks.WorkspaceClient, scope, k return nil } -func PutSecretInScope(ctx context.Context, client *databricks.WorkspaceClient, clusterID, key, value string) (string, error) { - scopeName, err := CreateKeysSecretScope(ctx, client, clusterID) +// PutSecretInScope creates the secret scope if needed and stores the secret. +// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). +func PutSecretInScope(ctx context.Context, client *databricks.WorkspaceClient, sessionID, key, value string) (string, error) { + scopeName, err := CreateKeysSecretScope(ctx, client, sessionID) if err != nil { return "", err } diff --git a/experimental/ssh/internal/server/jupyter-init.py b/experimental/ssh/internal/server/jupyter-init.py index 3e58b2f94a..0cd463a581 100644 --- a/experimental/ssh/internal/server/jupyter-init.py +++ b/experimental/ssh/internal/server/jupyter-init.py @@ -186,6 +186,46 @@ def df_html(df: DataFrame) -> str: html_formatter.for_type(DataFrame, df_html) +@_log_exceptions +def _initialize_spark_connect_session_grpc(): + import os + from dbruntime.spark_connection import get_and_configure_uds_spark + os.environ["SPARK_REMOTE"] = "unix:///databricks/sparkconnect/grpc.sock" + spark = get_and_configure_uds_spark() + globals()["spark"] = spark + + +@_log_exceptions +def _initialize_spark_connect_session_dbconnect(): + import IPython + from databricks.connect import DatabricksSession + user_ns = getattr(IPython.get_ipython(), "user_ns", {}) + existing_session = getattr(user_ns, "spark", None) + if existing_session is not None and _is_spark_connect(existing_session): + return + try: + # Clear the existing local spark session, otherwise DatabricksSession will re-use it. + user_ns["spark"] = None + globals()["spark"] = None + # DatabricksSession will use the existing env vars for the connection. + spark_session = DatabricksSession.builder.getOrCreate() + user_ns["spark"] = spark_session + globals()["spark"] = spark_session + except Exception as e: + user_ns["spark"] = existing_session + globals()["spark"] = existing_session + raise e + + +def _is_spark_connect(session) -> bool: + try: + from pyspark.sql.connect.session import SparkSession as ConnectSparkSession + return isinstance(session, ConnectSparkSession) + except ImportError: + return False + + _register_magics() _register_formatters() _register_runtime_hooks() +_initialize_spark_connect_session_dbconnect() diff --git a/experimental/ssh/internal/server/server.go b/experimental/ssh/internal/server/server.go index 66837cbc72..92fa76050a 100644 --- a/experimental/ssh/internal/server/server.go +++ b/experimental/ssh/internal/server/server.go @@ -29,8 +29,11 @@ type ServerOptions struct { MaxClients int // Delay before shutting down the server when there are no active connections ShutdownDelay time.Duration - // The cluster ID that the client started this server on + // The cluster ID that the client started this server on (required for Driver Proxy connections) ClusterID string + // SessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). + // Used for metadata storage path. Defaults to ClusterID if not set. + SessionID string // The directory to store sshd configuration ConfigDir string // The name of the secrets scope to use for client and server keys @@ -56,7 +59,12 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ServerOpt listenAddr := fmt.Sprintf("0.0.0.0:%d", port) log.Info(ctx, "Starting server on "+listenAddr) - err = workspace.SaveWorkspaceMetadata(ctx, client, opts.Version, opts.ClusterID, port) + // Save metadata including ClusterID (required for Driver Proxy connections in serverless mode) + metadata := &workspace.WorkspaceMetadata{ + Port: port, + ClusterID: opts.ClusterID, + } + err = workspace.SaveWorkspaceMetadata(ctx, client, opts.Version, opts.SessionID, metadata) if err != nil { return fmt.Errorf("failed to save metadata to the workspace: %w", err) } @@ -77,6 +85,10 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ServerOpt connections := proxy.NewConnectionsManager(opts.MaxClients, opts.ShutdownDelay) http.Handle("/ssh", proxy.NewProxyServer(ctx, connections, createServerCommand)) http.HandleFunc("/metadata", serveMetadata) + + http.Handle("/driver-proxy-http/ssh", proxy.NewProxyServer(ctx, connections, createServerCommand)) + http.HandleFunc("/driver-proxy-http/metadata", serveMetadata) + go handleTimeout(ctx, connections.TimedOut, opts.ShutdownDelay) return http.ListenAndServe(listenAddr, nil) diff --git a/experimental/ssh/internal/server/sshd.go b/experimental/ssh/internal/server/sshd.go index 2f45588a33..c8f23d02a5 100644 --- a/experimental/ssh/internal/server/sshd.go +++ b/experimental/ssh/internal/server/sshd.go @@ -36,7 +36,7 @@ func prepareSSHDConfig(ctx context.Context, client *databricks.WorkspaceClient, return "", fmt.Errorf("failed to create SSH directory: %w", err) } - privateKeyBytes, publicKeyBytes, err := keys.CheckAndGenerateSSHKeyPairFromSecrets(ctx, client, opts.ClusterID, opts.SecretScopeName, opts.ServerPrivateKeyName, opts.ServerPublicKeyName) + privateKeyBytes, publicKeyBytes, err := keys.CheckAndGenerateSSHKeyPairFromSecrets(ctx, client, opts.SecretScopeName, opts.ServerPrivateKeyName, opts.ServerPublicKeyName) if err != nil { return "", fmt.Errorf("failed to get SSH key pair from secrets: %w", err) } @@ -52,15 +52,13 @@ func prepareSSHDConfig(ctx context.Context, client *databricks.WorkspaceClient, return "", err } - // Set all available env vars, wrapping values in quotes and escaping quotes inside values + // Set all available env vars, wrapping values in quotes, escaping quotes, and stripping newlines setEnv := "SetEnv" for _, env := range os.Environ() { parts := strings.SplitN(env, "=", 2) - if len(parts) != 2 { - continue + if len(parts) == 2 { + setEnv += " " + parts[0] + "=\"" + escapeEnvValue(parts[1]) + "\"" } - valEscaped := strings.ReplaceAll(parts[1], "\"", "\\\"") - setEnv += " " + parts[0] + "=\"" + valEscaped + "\"" } setEnv += " DATABRICKS_CLI_UPSTREAM=databricks_ssh_tunnel" setEnv += " DATABRICKS_CLI_UPSTREAM_VERSION=" + opts.Version @@ -94,3 +92,13 @@ func prepareSSHDConfig(ctx context.Context, client *databricks.WorkspaceClient, func createSSHDProcess(ctx context.Context, configPath string) *exec.Cmd { return exec.CommandContext(ctx, "/usr/sbin/sshd", "-f", configPath, "-i") } + +// escapeEnvValue escapes a value for use in sshd SetEnv directive. +// It strips newlines and escapes backslashes and quotes. +func escapeEnvValue(val string) string { + val = strings.ReplaceAll(val, "\r", "") + val = strings.ReplaceAll(val, "\n", "") + val = strings.ReplaceAll(val, "\\", "\\\\") + val = strings.ReplaceAll(val, "\"", "\\\"") + return val +} diff --git a/experimental/ssh/internal/server/sshd_test.go b/experimental/ssh/internal/server/sshd_test.go new file mode 100644 index 0000000000..a453d987a0 --- /dev/null +++ b/experimental/ssh/internal/server/sshd_test.go @@ -0,0 +1,73 @@ +package server + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestEscapeEnvValue(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "simple value", + input: "hello", + expected: "hello", + }, + { + name: "value with quotes", + input: `say "hello"`, + expected: `say \"hello\"`, + }, + { + name: "value with newline", + input: "line1\nline2", + expected: "line1line2", + }, + { + name: "value with carriage return", + input: "line1\rline2", + expected: "line1line2", + }, + { + name: "value with CRLF", + input: "line1\r\nline2", + expected: "line1line2", + }, + { + name: "value with quotes and newlines", + input: "say \"hello\"\nworld", + expected: `say \"hello\"world`, + }, + { + name: "empty value", + input: "", + expected: "", + }, + { + name: "only newlines", + input: "\n\r\n", + expected: "", + }, + { + name: "backslashes", + input: `foo\bar\`, + expected: `foo\\bar\\`, + }, + { + name: "backslash before quote", + input: `foo\"bar`, + expected: `foo\\\"bar`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := escapeEnvValue(tt.input) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/experimental/ssh/internal/setup/setup.go b/experimental/ssh/internal/setup/setup.go index 1c60a1e4f7..99b5a68902 100644 --- a/experimental/ssh/internal/setup/setup.go +++ b/experimental/ssh/internal/setup/setup.go @@ -4,14 +4,10 @@ import ( "context" "errors" "fmt" - "os" - "path/filepath" - "regexp" - "strconv" - "strings" "time" "github.com/databricks/cli/experimental/ssh/internal/keys" + "github.com/databricks/cli/experimental/ssh/internal/sshconfig" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/service/compute" @@ -32,6 +28,8 @@ type SetupOptions struct { SSHKeysDir string // Optional auth profile name. If present, will be added as --profile flag to the ProxyCommand Profile string + // Proxy command to use for the SSH connection + ProxyCommand string } func validateClusterAccess(ctx context.Context, client *databricks.WorkspaceClient, clusterID string) error { @@ -45,126 +43,16 @@ func validateClusterAccess(ctx context.Context, client *databricks.WorkspaceClie return nil } -func resolveConfigPath(configPath string) (string, error) { - if configPath != "" { - return configPath, nil - } - homeDir, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("failed to get home directory: %w", err) - } - return filepath.Join(homeDir, ".ssh", "config"), nil -} - -func GenerateProxyCommand(clusterId string, autoStartCluster bool, shutdownDelay time.Duration, profile, userName string, serverPort int, handoverTimeout time.Duration) (string, error) { - executablePath, err := os.Executable() - if err != nil { - return "", fmt.Errorf("failed to get current executable path: %w", err) - } - - proxyCommand := fmt.Sprintf("%q ssh connect --proxy --cluster=%s --auto-start-cluster=%t --shutdown-delay=%s", - executablePath, clusterId, autoStartCluster, shutdownDelay.String()) - - if userName != "" && serverPort != 0 { - proxyCommand += " --metadata=" + userName + "," + strconv.Itoa(serverPort) - } - - if handoverTimeout > 0 { - proxyCommand += " --handover-timeout=" + handoverTimeout.String() - } - - if profile != "" { - proxyCommand += " --profile=" + profile - } - - return proxyCommand, nil -} - func generateHostConfig(opts SetupOptions) (string, error) { identityFilePath, err := keys.GetLocalSSHKeyPath(opts.ClusterID, opts.SSHKeysDir) if err != nil { return "", fmt.Errorf("failed to get local keys folder: %w", err) } - proxyCommand, err := GenerateProxyCommand(opts.ClusterID, opts.AutoStartCluster, opts.ShutdownDelay, opts.Profile, "", 0, 0) - if err != nil { - return "", fmt.Errorf("failed to generate ProxyCommand: %w", err) - } - - hostConfig := fmt.Sprintf(` -Host %s - User root - ConnectTimeout 360 - StrictHostKeyChecking accept-new - IdentitiesOnly yes - IdentityFile %q - ProxyCommand %s -`, opts.HostName, identityFilePath, proxyCommand) - + hostConfig := sshconfig.GenerateHostConfig(opts.HostName, "root", identityFilePath, opts.ProxyCommand) return hostConfig, nil } -func ensureSSHConfigExists(configPath string) error { - _, err := os.Stat(configPath) - if os.IsNotExist(err) { - sshDir := filepath.Dir(configPath) - err = os.MkdirAll(sshDir, 0o700) - if err != nil { - return fmt.Errorf("failed to create SSH directory: %w", err) - } - err = os.WriteFile(configPath, []byte(""), 0o600) - if err != nil { - return fmt.Errorf("failed to create SSH config file: %w", err) - } - return nil - } else if err != nil { - return fmt.Errorf("failed to check SSH config file: %w", err) - } - return nil -} - -func checkExistingHosts(content []byte, hostName string) (bool, error) { - existingContent := string(content) - pattern := fmt.Sprintf(`(?m)^\s*Host\s+%s\s*$`, regexp.QuoteMeta(hostName)) - matched, err := regexp.MatchString(pattern, existingContent) - if err != nil { - return false, fmt.Errorf("failed to check for existing host: %w", err) - } - if matched { - return true, nil - } - return false, nil -} - -func createBackup(content []byte, configPath string) (string, error) { - backupPath := configPath + ".bak" - err := os.WriteFile(backupPath, content, 0o600) - if err != nil { - return backupPath, fmt.Errorf("failed to create backup of SSH config file: %w", err) - } - return backupPath, nil -} - -func updateSSHConfigFile(configPath, hostConfig, hostName string) error { - content, err := os.ReadFile(configPath) - if err != nil { - return fmt.Errorf("failed to read SSH config file: %w", err) - } - - existingContent := string(content) - if !strings.HasSuffix(existingContent, "\n") && existingContent != "" { - existingContent += "\n" - } - newContent := existingContent + hostConfig - - err = os.WriteFile(configPath, []byte(newContent), 0o600) - if err != nil { - return fmt.Errorf("failed to update SSH config file: %w", err) - } - - return nil -} - func clusterSelectionPrompt(ctx context.Context, client *databricks.WorkspaceClient) (string, error) { sp := cmdio.NewSpinner(ctx) sp.Update("Loading clusters.") @@ -202,50 +90,51 @@ func Setup(ctx context.Context, client *databricks.WorkspaceClient, opts SetupOp return err } - configPath, err := resolveConfigPath(opts.SSHConfigPath) + configPath, err := sshconfig.GetMainConfigPathOrDefault(opts.SSHConfigPath) if err != nil { return err } - hostConfig, err := generateHostConfig(opts) + err = sshconfig.EnsureIncludeDirective(configPath) if err != nil { return err } - err = ensureSSHConfigExists(configPath) + hostConfig, err := generateHostConfig(opts) if err != nil { return err } - existingContent, err := os.ReadFile(configPath) + exists, err := sshconfig.HostConfigExists(opts.HostName) if err != nil { - return fmt.Errorf("failed to read SSH config file: %w", err) + return err } - if len(existingContent) > 0 { - exists, err := checkExistingHosts(existingContent, opts.HostName) + recreate := false + if exists { + recreate, err = sshconfig.PromptRecreateConfig(ctx, opts.HostName) if err != nil { return err } - if exists { - cmdio.LogString(ctx, fmt.Sprintf("Host '%s' already exists in the SSH config, skipping setup", opts.HostName)) + if !recreate { + cmdio.LogString(ctx, fmt.Sprintf("Skipping setup for host '%s'", opts.HostName)) return nil } - backupPath, err := createBackup(existingContent, configPath) - if err != nil { - return err - } - cmdio.LogString(ctx, "Created backup of existing SSH config at "+backupPath) } cmdio.LogString(ctx, "Adding new entry to the SSH config:\n"+hostConfig) - err = updateSSHConfigFile(configPath, hostConfig, opts.HostName) + _, err = sshconfig.CreateOrUpdateHostConfig(ctx, opts.HostName, hostConfig, recreate) + if err != nil { + return err + } + + hostConfigPath, err := sshconfig.GetHostConfigPath(opts.HostName) if err != nil { return err } - cmdio.LogString(ctx, fmt.Sprintf("Updated SSH config file at %s with '%s' host", configPath, opts.HostName)) + cmdio.LogString(ctx, fmt.Sprintf("Created SSH config file at %s for '%s' host", hostConfigPath, opts.HostName)) cmdio.LogString(ctx, fmt.Sprintf("You can now connect to the cluster using 'ssh %s' terminal command, or use remote capabilities of your IDE", opts.HostName)) return nil } diff --git a/experimental/ssh/internal/setup/setup_test.go b/experimental/ssh/internal/setup/setup_test.go index 7c4cb20925..975828a3c8 100644 --- a/experimental/ssh/internal/setup/setup_test.go +++ b/experimental/ssh/internal/setup/setup_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/databricks/cli/experimental/ssh/internal/client" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/databricks-sdk-go/experimental/mocks" "github.com/databricks/databricks-sdk-go/service/compute" @@ -56,7 +57,12 @@ func TestValidateClusterAccess_ClusterNotFound(t *testing.T) { } func TestGenerateProxyCommand(t *testing.T) { - cmd, err := GenerateProxyCommand("cluster-123", true, 45*time.Second, "", "", 0, 0) + opts := client.ClientOptions{ + ClusterID: "cluster-123", + AutoStartCluster: true, + ShutdownDelay: 45 * time.Second, + } + cmd, err := opts.ToProxyCommand() assert.NoError(t, err) assert.Contains(t, cmd, "ssh connect --proxy --cluster=cluster-123 --auto-start-cluster=true --shutdown-delay=45s") assert.NotContains(t, cmd, "--metadata") @@ -65,7 +71,15 @@ func TestGenerateProxyCommand(t *testing.T) { } func TestGenerateProxyCommand_WithExtraArgs(t *testing.T) { - cmd, err := GenerateProxyCommand("cluster-123", true, 45*time.Second, "test-profile", "user", 2222, 2*time.Minute) + opts := client.ClientOptions{ + ClusterID: "cluster-123", + AutoStartCluster: true, + ShutdownDelay: 45 * time.Second, + Profile: "test-profile", + ServerMetadata: "user,2222", + HandoverTimeout: 2 * time.Minute, + } + cmd, err := opts.ToProxyCommand() assert.NoError(t, err) assert.Contains(t, cmd, "ssh connect --proxy --cluster=cluster-123 --auto-start-cluster=true --shutdown-delay=45s") assert.Contains(t, cmd, " --metadata=user,2222") @@ -73,16 +87,55 @@ func TestGenerateProxyCommand_WithExtraArgs(t *testing.T) { assert.Contains(t, cmd, " --profile=test-profile") } +func TestGenerateProxyCommand_ServerlessMode(t *testing.T) { + opts := client.ClientOptions{ + ConnectionName: "my-connection", + ShutdownDelay: 45 * time.Second, + ServerMetadata: "user,2222,serverless-cluster-id", + } + cmd, err := opts.ToProxyCommand() + assert.NoError(t, err) + assert.Contains(t, cmd, "ssh connect --proxy --name=my-connection --shutdown-delay=45s") + assert.Contains(t, cmd, " --metadata=user,2222,serverless-cluster-id") + assert.NotContains(t, cmd, "--cluster=") + assert.NotContains(t, cmd, "--auto-start-cluster") +} + +func TestGenerateProxyCommand_ServerlessModeWithAccelerator(t *testing.T) { + opts := client.ClientOptions{ + ConnectionName: "my-connection", + ShutdownDelay: 45 * time.Second, + Accelerator: "GPU_1xA10", + ServerMetadata: "user,2222,serverless-cluster-id", + } + cmd, err := opts.ToProxyCommand() + assert.NoError(t, err) + assert.Contains(t, cmd, "ssh connect --proxy --name=my-connection --shutdown-delay=45s") + assert.Contains(t, cmd, " --accelerator=GPU_1xA10") + assert.Contains(t, cmd, " --metadata=user,2222,serverless-cluster-id") + assert.NotContains(t, cmd, "--cluster=") + assert.NotContains(t, cmd, "--auto-start-cluster") +} + func TestGenerateHostConfig_Valid(t *testing.T) { - // Create a temporary directory for testing tmpDir := t.TempDir() + clientOpts := client.ClientOptions{ + ClusterID: "cluster-123", + AutoStartCluster: true, + ShutdownDelay: 30 * time.Second, + Profile: "test-profile", + } + proxyCommand, err := clientOpts.ToProxyCommand() + require.NoError(t, err) + opts := SetupOptions{ HostName: "test-host", ClusterID: "cluster-123", SSHKeysDir: tmpDir, ShutdownDelay: 30 * time.Second, Profile: "test-profile", + ProxyCommand: proxyCommand, } result, err := generateHostConfig(opts) @@ -95,29 +148,35 @@ func TestGenerateHostConfig_Valid(t *testing.T) { assert.Contains(t, result, "--shutdown-delay=30s") assert.Contains(t, result, "--profile=test-profile") - // Check that identity file path is included expectedKeyPath := filepath.Join(tmpDir, "cluster-123") assert.Contains(t, result, fmt.Sprintf(`IdentityFile %q`, expectedKeyPath)) } func TestGenerateHostConfig_WithoutProfile(t *testing.T) { - // Create a temporary directory for testing tmpDir := t.TempDir() + clientOpts := client.ClientOptions{ + ClusterID: "cluster-123", + AutoStartCluster: true, + ShutdownDelay: 30 * time.Second, + Profile: "", + } + proxyCommand, err := clientOpts.ToProxyCommand() + require.NoError(t, err) + opts := SetupOptions{ HostName: "test-host", ClusterID: "cluster-123", SSHKeysDir: tmpDir, ShutdownDelay: 30 * time.Second, - Profile: "", // No profile + Profile: "", + ProxyCommand: proxyCommand, } result, err := generateHostConfig(opts) assert.NoError(t, err) - // Should not contain profile option assert.NotContains(t, result, "--profile=") - // But should contain other elements assert.Contains(t, result, "Host test-host") assert.Contains(t, result, "--cluster=cluster-123") } @@ -143,181 +202,12 @@ func TestGenerateHostConfig_PathEscaping(t *testing.T) { assert.Contains(t, result, fmt.Sprintf(`IdentityFile %q`, expectedPath)) } -func TestEnsureSSHConfigExists(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, ".ssh", "config") - - err := ensureSSHConfigExists(configPath) - assert.NoError(t, err) - - // Check that directory was created - _, err = os.Stat(filepath.Dir(configPath)) - assert.NoError(t, err) - - // Check that file was created - _, err = os.Stat(configPath) - assert.NoError(t, err) - - // Check that file is empty - content, err := os.ReadFile(configPath) - assert.NoError(t, err) - assert.Empty(t, content) -} - -func TestCheckExistingHosts_NoExistingHost(t *testing.T) { - content := []byte(`Host other-host - User root - HostName example.com - -Host another-host - User admin -`) - exists, err := checkExistingHosts(content, "test-host") - assert.NoError(t, err) - assert.False(t, exists) -} - -func TestCheckExistingHosts_HostAlreadyExists(t *testing.T) { - content := []byte(`Host test-host - User root - HostName example.com - -Host another-host - User admin -`) - exists, err := checkExistingHosts(content, "another-host") - assert.NoError(t, err) - assert.True(t, exists) -} - -func TestCheckExistingHosts_EmptyContent(t *testing.T) { - content := []byte("") - exists, err := checkExistingHosts(content, "test-host") - assert.NoError(t, err) - assert.False(t, exists) -} - -func TestCheckExistingHosts_HostNameWithWhitespaces(t *testing.T) { - content := []byte(` Host test-host `) - exists, err := checkExistingHosts(content, "test-host") - assert.NoError(t, err) - assert.True(t, exists) -} - -func TestCheckExistingHosts_PartialNameMatch(t *testing.T) { - content := []byte(`Host test-host-long`) - exists, err := checkExistingHosts(content, "test-host") - assert.NoError(t, err) - assert.False(t, exists) -} - -func TestCreateBackup_CreatesBackupSuccessfully(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - content := []byte("original content") - - backupPath, err := createBackup(content, configPath) - assert.NoError(t, err) - assert.Equal(t, configPath+".bak", backupPath) - - // Check that backup file was created with correct content - backupContent, err := os.ReadFile(backupPath) - assert.NoError(t, err) - assert.Equal(t, content, backupContent) -} - -func TestCreateBackup_OverwritesExistingBackup(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - backupPath := configPath + ".bak" - - // Create existing backup - oldContent := []byte("old backup") - err := os.WriteFile(backupPath, oldContent, 0o644) - require.NoError(t, err) - - // Create new backup - newContent := []byte("new content") - resultPath, err := createBackup(newContent, configPath) - assert.NoError(t, err) - assert.Equal(t, backupPath, resultPath) - - // Check that backup was overwritten - backupContent, err := os.ReadFile(backupPath) - assert.NoError(t, err) - assert.Equal(t, newContent, backupContent) -} - -func TestUpdateSSHConfigFile_UpdatesSuccessfully(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - - // Create initial config file - initialContent := "# SSH Config\nHost existing\n User root\n" - err := os.WriteFile(configPath, []byte(initialContent), 0o600) - require.NoError(t, err) - - hostConfig := "\nHost new-host\n User root\n HostName example.com\n" - err = updateSSHConfigFile(configPath, hostConfig, "new-host") - assert.NoError(t, err) - - // Check that content was appended - finalContent, err := os.ReadFile(configPath) - assert.NoError(t, err) - expected := initialContent + hostConfig - assert.Equal(t, expected, string(finalContent)) -} - -func TestUpdateSSHConfigFile_AddsNewlineIfMissing(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - - // Create config file without trailing newline - initialContent := "Host existing\n User root" - err := os.WriteFile(configPath, []byte(initialContent), 0o600) - require.NoError(t, err) - - hostConfig := "\nHost new-host\n User root\n" - err = updateSSHConfigFile(configPath, hostConfig, "new-host") - assert.NoError(t, err) - - // Check that newline was added before the new content - finalContent, err := os.ReadFile(configPath) - assert.NoError(t, err) - expected := initialContent + "\n" + hostConfig - assert.Equal(t, expected, string(finalContent)) -} - -func TestUpdateSSHConfigFile_HandlesEmptyFile(t *testing.T) { - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "config") - - // Create empty config file - err := os.WriteFile(configPath, []byte(""), 0o600) - require.NoError(t, err) - - hostConfig := "Host new-host\n User root\n" - err = updateSSHConfigFile(configPath, hostConfig, "new-host") - assert.NoError(t, err) - - // Check that content was added without extra newlines - finalContent, err := os.ReadFile(configPath) - assert.NoError(t, err) - assert.Equal(t, hostConfig, string(finalContent)) -} - -func TestUpdateSSHConfigFile_HandlesReadError(t *testing.T) { - configPath := "/nonexistent/file" - hostConfig := "Host new-host\n" - - err := updateSSHConfigFile(configPath, hostConfig, "new-host") - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to read SSH config file") -} - func TestSetup_SuccessfulWithNewConfigFile(t *testing.T) { ctx := cmdio.MockDiscard(context.Background()) tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + configPath := filepath.Join(tmpDir, "ssh_config") m := mocks.NewMockWorkspaceClient(t) @@ -336,22 +226,43 @@ func TestSetup_SuccessfulWithNewConfigFile(t *testing.T) { Profile: "test-profile", } - err := Setup(ctx, m.WorkspaceClient, opts) + clientOpts := client.ClientOptions{ + ClusterID: opts.ClusterID, + AutoStartCluster: opts.AutoStartCluster, + ShutdownDelay: opts.ShutdownDelay, + Profile: opts.Profile, + } + proxyCommand, err := clientOpts.ToProxyCommand() + require.NoError(t, err) + opts.ProxyCommand = proxyCommand + + err = Setup(ctx, m.WorkspaceClient, opts) assert.NoError(t, err) - // Check that config file was created + // Check that main config has Include directive content, err := os.ReadFile(configPath) assert.NoError(t, err) - configStr := string(content) - assert.Contains(t, configStr, "Host test-host") - assert.Contains(t, configStr, "--cluster=cluster-123") - assert.Contains(t, configStr, "--profile=test-profile") + assert.Contains(t, configStr, "Include") + // SSH config uses forward slashes on all platforms + assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") + + // Check that host config file was created + hostConfigPath := filepath.Join(tmpDir, ".databricks", "ssh-tunnel-configs", "test-host") + hostContent, err := os.ReadFile(hostConfigPath) + assert.NoError(t, err) + hostConfigStr := string(hostContent) + assert.Contains(t, hostConfigStr, "Host test-host") + assert.Contains(t, hostConfigStr, "--cluster=cluster-123") + assert.Contains(t, hostConfigStr, "--profile=test-profile") } func TestSetup_SuccessfulWithExistingConfigFile(t *testing.T) { ctx := cmdio.MockDiscard(context.Background()) tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + configPath := filepath.Join(tmpDir, "ssh_config") // Create existing config file @@ -374,54 +285,34 @@ func TestSetup_SuccessfulWithExistingConfigFile(t *testing.T) { ShutdownDelay: 60 * time.Second, } - err = Setup(ctx, m.WorkspaceClient, opts) - assert.NoError(t, err) - - // Check that config file was updated and backup was created - content, err := os.ReadFile(configPath) - assert.NoError(t, err) - - configStr := string(content) - assert.Contains(t, configStr, "# Existing SSH Config") // Original content preserved - assert.Contains(t, configStr, "Host new-host") // New content added - assert.Contains(t, configStr, "--cluster=cluster-456") - - // Check backup was created - backupPath := configPath + ".bak" - backupContent, err := os.ReadFile(backupPath) - assert.NoError(t, err) - assert.Equal(t, existingContent, string(backupContent)) -} - -func TestSetup_DoesNotOverrideExistingHost(t *testing.T) { - ctx := cmdio.MockDiscard(context.Background()) - tmpDir := t.TempDir() - configPath := filepath.Join(tmpDir, "ssh_config") - - // Create config file with existing host - existingContent := "Host duplicate-host\n User root\n" - err := os.WriteFile(configPath, []byte(existingContent), 0o600) - require.NoError(t, err) - - m := mocks.NewMockWorkspaceClient(t) - clustersAPI := m.GetMockClustersAPI() - - clustersAPI.EXPECT().Get(ctx, compute.GetClusterRequest{ClusterId: "cluster-123"}).Return(&compute.ClusterDetails{ - DataSecurityMode: compute.DataSecurityModeSingleUser, - }, nil) - - opts := SetupOptions{ - HostName: "duplicate-host", // Same as existing - ClusterID: "cluster-123", - SSHConfigPath: configPath, - SSHKeysDir: tmpDir, - ShutdownDelay: 30 * time.Second, + clientOpts := client.ClientOptions{ + ClusterID: opts.ClusterID, + AutoStartCluster: opts.AutoStartCluster, + ShutdownDelay: opts.ShutdownDelay, + Profile: opts.Profile, } + proxyCommand, err := clientOpts.ToProxyCommand() + require.NoError(t, err) + opts.ProxyCommand = proxyCommand err = Setup(ctx, m.WorkspaceClient, opts) assert.NoError(t, err) + // Check that main config has Include directive and preserves existing content content, err := os.ReadFile(configPath) assert.NoError(t, err) - assert.Equal(t, "Host duplicate-host\n User root\n", string(content)) + configStr := string(content) + assert.Contains(t, configStr, "Include") + // SSH config uses forward slashes on all platforms + assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") + assert.Contains(t, configStr, "# Existing SSH Config") + assert.Contains(t, configStr, "Host existing-host") + + // Check that host config file was created + hostConfigPath := filepath.Join(tmpDir, ".databricks", "ssh-tunnel-configs", "new-host") + hostContent, err := os.ReadFile(hostConfigPath) + assert.NoError(t, err) + hostConfigStr := string(hostContent) + assert.Contains(t, hostConfigStr, "Host new-host") + assert.Contains(t, hostConfigStr, "--cluster=cluster-456") } diff --git a/experimental/ssh/internal/sshconfig/sshconfig.go b/experimental/ssh/internal/sshconfig/sshconfig.go new file mode 100644 index 0000000000..3a6713acbf --- /dev/null +++ b/experimental/ssh/internal/sshconfig/sshconfig.go @@ -0,0 +1,172 @@ +package sshconfig + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/databricks/cli/libs/cmdio" +) + +const ( + // configDirName is the directory name for Databricks SSH tunnel configs, relative to the user's home directory. + configDirName = ".databricks/ssh-tunnel-configs" +) + +func GetConfigDir() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get home directory: %w", err) + } + return filepath.Join(homeDir, configDirName), nil +} + +func GetMainConfigPath() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get home directory: %w", err) + } + return filepath.Join(homeDir, ".ssh", "config"), nil +} + +func GetMainConfigPathOrDefault(configPath string) (string, error) { + if configPath != "" { + return configPath, nil + } + return GetMainConfigPath() +} + +func EnsureMainConfigExists(configPath string) error { + _, err := os.Stat(configPath) + if os.IsNotExist(err) { + sshDir := filepath.Dir(configPath) + err = os.MkdirAll(sshDir, 0o700) + if err != nil { + return fmt.Errorf("failed to create SSH directory: %w", err) + } + err = os.WriteFile(configPath, []byte(""), 0o600) + if err != nil { + return fmt.Errorf("failed to create SSH config file: %w", err) + } + return nil + } + return err +} + +func EnsureIncludeDirective(configPath string) error { + configDir, err := GetConfigDir() + if err != nil { + return err + } + + err = os.MkdirAll(configDir, 0o700) + if err != nil { + return fmt.Errorf("failed to create Databricks SSH config directory: %w", err) + } + + err = EnsureMainConfigExists(configPath) + if err != nil { + return err + } + + content, err := os.ReadFile(configPath) + if err != nil { + return fmt.Errorf("failed to read SSH config file: %w", err) + } + + // Convert path to forward slashes for SSH config compatibility across platforms + configDirUnix := filepath.ToSlash(configDir) + + includeLine := fmt.Sprintf("Include %s/*", configDirUnix) + if strings.Contains(string(content), includeLine) { + return nil + } + + newContent := includeLine + "\n" + if len(content) > 0 && !strings.HasPrefix(string(content), "\n") { + newContent += "\n" + } + newContent += string(content) + + err = os.WriteFile(configPath, []byte(newContent), 0o600) + if err != nil { + return fmt.Errorf("failed to update SSH config file with Include directive: %w", err) + } + + return nil +} + +func GetHostConfigPath(hostName string) (string, error) { + configDir, err := GetConfigDir() + if err != nil { + return "", err + } + return filepath.Join(configDir, hostName), nil +} + +func HostConfigExists(hostName string) (bool, error) { + configPath, err := GetHostConfigPath(hostName) + if err != nil { + return false, err + } + _, err = os.Stat(configPath) + if os.IsNotExist(err) { + return false, nil + } + if err != nil { + return false, fmt.Errorf("failed to check host config file: %w", err) + } + return true, nil +} + +// Returns true if the config was created/updated, false if it was skipped. +func CreateOrUpdateHostConfig(ctx context.Context, hostName, hostConfig string, recreate bool) (bool, error) { + configPath, err := GetHostConfigPath(hostName) + if err != nil { + return false, err + } + + exists, err := HostConfigExists(hostName) + if err != nil { + return false, err + } + + if exists && !recreate { + return false, nil + } + + configDir := filepath.Dir(configPath) + err = os.MkdirAll(configDir, 0o700) + if err != nil { + return false, fmt.Errorf("failed to create config directory: %w", err) + } + + err = os.WriteFile(configPath, []byte(hostConfig), 0o600) + if err != nil { + return false, fmt.Errorf("failed to write host config file: %w", err) + } + + return true, nil +} + +func PromptRecreateConfig(ctx context.Context, hostName string) (bool, error) { + response, err := cmdio.AskYesOrNo(ctx, fmt.Sprintf("Host '%s' already exists. Do you want to recreate the config?", hostName)) + if err != nil { + return false, err + } + return response, nil +} + +func GenerateHostConfig(hostName, userName, identityFile, proxyCommand string) string { + return fmt.Sprintf(` +Host %s + User %s + ConnectTimeout 360 + StrictHostKeyChecking accept-new + IdentitiesOnly yes + IdentityFile %q + ProxyCommand %s +`, hostName, userName, identityFile, proxyCommand) +} diff --git a/experimental/ssh/internal/sshconfig/sshconfig_test.go b/experimental/ssh/internal/sshconfig/sshconfig_test.go new file mode 100644 index 0000000000..5fa13923ee --- /dev/null +++ b/experimental/ssh/internal/sshconfig/sshconfig_test.go @@ -0,0 +1,223 @@ +package sshconfig + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/databricks/cli/libs/cmdio" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetConfigDir(t *testing.T) { + dir, err := GetConfigDir() + assert.NoError(t, err) + assert.Contains(t, dir, filepath.Join(".databricks", "ssh-tunnel-configs")) +} + +func TestGetMainConfigPath(t *testing.T) { + path, err := GetMainConfigPath() + assert.NoError(t, err) + assert.Contains(t, path, filepath.Join(".ssh", "config")) +} + +func TestGetMainConfigPathOrDefault(t *testing.T) { + path, err := GetMainConfigPathOrDefault("/custom/path") + assert.NoError(t, err) + assert.Equal(t, "/custom/path", path) + + path, err = GetMainConfigPathOrDefault("") + assert.NoError(t, err) + assert.Contains(t, path, filepath.Join(".ssh", "config")) +} + +func TestEnsureMainConfigExists(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".ssh", "config") + + err := EnsureMainConfigExists(configPath) + assert.NoError(t, err) + + _, err = os.Stat(filepath.Dir(configPath)) + assert.NoError(t, err) + + _, err = os.Stat(configPath) + assert.NoError(t, err) + + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Empty(t, content) +} + +func TestEnsureIncludeDirective_NewConfig(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".ssh", "config") + + // Set home directory for test + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + err := EnsureIncludeDirective(configPath) + assert.NoError(t, err) + + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + + configStr := string(content) + assert.Contains(t, configStr, "Include") + // SSH config uses forward slashes on all platforms + assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") +} + +func TestEnsureIncludeDirective_AlreadyExists(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + configPath := filepath.Join(tmpDir, ".ssh", "config") + + configDir, err := GetConfigDir() + require.NoError(t, err) + + // Use forward slashes as that's what SSH config uses + configDirUnix := filepath.ToSlash(configDir) + existingContent := "Include " + configDirUnix + "/*\n\nHost example\n User test\n" + err = os.MkdirAll(filepath.Dir(configPath), 0o700) + require.NoError(t, err) + err = os.WriteFile(configPath, []byte(existingContent), 0o600) + require.NoError(t, err) + + err = EnsureIncludeDirective(configPath) + assert.NoError(t, err) + + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Equal(t, existingContent, string(content)) +} + +func TestEnsureIncludeDirective_PrependsToExisting(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, ".ssh", "config") + + // Set home directory for test + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + existingContent := "Host example\n User test\n" + err := os.MkdirAll(filepath.Dir(configPath), 0o700) + require.NoError(t, err) + err = os.WriteFile(configPath, []byte(existingContent), 0o600) + require.NoError(t, err) + + err = EnsureIncludeDirective(configPath) + assert.NoError(t, err) + + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + + configStr := string(content) + assert.Contains(t, configStr, "Include") + // SSH config uses forward slashes on all platforms + assert.Contains(t, configStr, ".databricks/ssh-tunnel-configs/*") + assert.Contains(t, configStr, "Host example") + + includeIndex := len("Include") + hostIndex := len(configStr) - len(existingContent) + assert.Less(t, includeIndex, hostIndex, "Include directive should come before existing content") +} + +func TestGetHostConfigPath(t *testing.T) { + path, err := GetHostConfigPath("test-host") + assert.NoError(t, err) + assert.Contains(t, path, filepath.Join(".databricks", "ssh-tunnel-configs", "test-host")) +} + +func TestHostConfigExists(t *testing.T) { + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + exists, err := HostConfigExists("nonexistent") + assert.NoError(t, err) + assert.False(t, exists) + + configDir := filepath.Join(tmpDir, configDirName) + err = os.MkdirAll(configDir, 0o700) + require.NoError(t, err) + err = os.WriteFile(filepath.Join(configDir, "existing-host"), []byte("config"), 0o600) + require.NoError(t, err) + + exists, err = HostConfigExists("existing-host") + assert.NoError(t, err) + assert.True(t, exists) +} + +func TestCreateOrUpdateHostConfig_NewConfig(t *testing.T) { + ctx := cmdio.MockDiscard(context.Background()) + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + hostConfig := "Host test\n User root\n" + created, err := CreateOrUpdateHostConfig(ctx, "test-host", hostConfig, false) + assert.NoError(t, err) + assert.True(t, created) + + configPath, err := GetHostConfigPath("test-host") + require.NoError(t, err) + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Equal(t, hostConfig, string(content)) +} + +func TestCreateOrUpdateHostConfig_ExistingConfigNoRecreate(t *testing.T) { + ctx := cmdio.MockDiscard(context.Background()) + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + configDir := filepath.Join(tmpDir, configDirName) + err := os.MkdirAll(configDir, 0o700) + require.NoError(t, err) + existingConfig := "Host test\n User admin\n" + err = os.WriteFile(filepath.Join(configDir, "test-host"), []byte(existingConfig), 0o600) + require.NoError(t, err) + + newConfig := "Host test\n User root\n" + created, err := CreateOrUpdateHostConfig(ctx, "test-host", newConfig, false) + assert.NoError(t, err) + assert.False(t, created) + + configPath, err := GetHostConfigPath("test-host") + require.NoError(t, err) + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Equal(t, existingConfig, string(content)) +} + +func TestCreateOrUpdateHostConfig_ExistingConfigWithRecreate(t *testing.T) { + ctx := cmdio.MockDiscard(context.Background()) + tmpDir := t.TempDir() + t.Setenv("HOME", tmpDir) + t.Setenv("USERPROFILE", tmpDir) + + configDir := filepath.Join(tmpDir, configDirName) + err := os.MkdirAll(configDir, 0o700) + require.NoError(t, err) + existingConfig := "Host test\n User admin\n" + err = os.WriteFile(filepath.Join(configDir, "test-host"), []byte(existingConfig), 0o600) + require.NoError(t, err) + + newConfig := "Host test\n User root\n" + created, err := CreateOrUpdateHostConfig(ctx, "test-host", newConfig, true) + assert.NoError(t, err) + assert.True(t, created) + + configPath, err := GetHostConfigPath("test-host") + require.NoError(t, err) + content, err := os.ReadFile(configPath) + assert.NoError(t, err) + assert.Equal(t, newConfig, string(content)) +} diff --git a/experimental/ssh/internal/workspace/workspace.go b/experimental/ssh/internal/workspace/workspace.go index 10e593951a..2f017cbae1 100644 --- a/experimental/ssh/internal/workspace/workspace.go +++ b/experimental/ssh/internal/workspace/workspace.go @@ -16,6 +16,8 @@ const metadataFileName = "metadata.json" type WorkspaceMetadata struct { Port int `json:"port"` + // ClusterID is required for Driver Proxy websocket connections (for any compute type, including serverless) + ClusterID string `json:"cluster_id,omitempty"` } func getWorkspaceRootDir(ctx context.Context, client *databricks.WorkspaceClient) (string, error) { @@ -34,49 +36,55 @@ func GetWorkspaceVersionedDir(ctx context.Context, client *databricks.WorkspaceC return filepath.ToSlash(filepath.Join(contentDir, version)), nil } -func GetWorkspaceContentDir(ctx context.Context, client *databricks.WorkspaceClient, version, clusterID string) (string, error) { +// GetWorkspaceContentDir returns the directory for storing session content. +// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). +func GetWorkspaceContentDir(ctx context.Context, client *databricks.WorkspaceClient, version, sessionID string) (string, error) { contentDir, err := GetWorkspaceVersionedDir(ctx, client, version) if err != nil { return "", fmt.Errorf("failed to get versioned workspace directory: %w", err) } - return filepath.ToSlash(filepath.Join(contentDir, clusterID)), nil + return filepath.ToSlash(filepath.Join(contentDir, sessionID)), nil } -func GetWorkspaceMetadata(ctx context.Context, client *databricks.WorkspaceClient, version, clusterID string) (int, error) { - contentDir, err := GetWorkspaceContentDir(ctx, client, version, clusterID) +// GetWorkspaceMetadata loads session metadata from the workspace. +// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). +func GetWorkspaceMetadata(ctx context.Context, client *databricks.WorkspaceClient, version, sessionID string) (*WorkspaceMetadata, error) { + contentDir, err := GetWorkspaceContentDir(ctx, client, version, sessionID) if err != nil { - return 0, fmt.Errorf("failed to get workspace content directory: %w", err) + return nil, fmt.Errorf("failed to get workspace content directory: %w", err) } metadataPath := filepath.ToSlash(filepath.Join(contentDir, metadataFileName)) content, err := client.Workspace.Download(ctx, metadataPath) if err != nil { - return 0, fmt.Errorf("failed to download metadata file: %w", err) + return nil, fmt.Errorf("failed to download metadata file: %w", err) } defer content.Close() metadataBytes, err := io.ReadAll(content) if err != nil { - return 0, fmt.Errorf("failed to read metadata content: %w", err) + return nil, fmt.Errorf("failed to read metadata content: %w", err) } var metadata WorkspaceMetadata err = json.Unmarshal(metadataBytes, &metadata) if err != nil { - return 0, fmt.Errorf("failed to parse metadata JSON: %w", err) + return nil, fmt.Errorf("failed to parse metadata JSON: %w", err) } - return metadata.Port, nil + return &metadata, nil } -func SaveWorkspaceMetadata(ctx context.Context, client *databricks.WorkspaceClient, version, clusterID string, port int) error { - metadataBytes, err := json.Marshal(WorkspaceMetadata{Port: port}) +// SaveWorkspaceMetadata saves session metadata to the workspace. +// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless). +func SaveWorkspaceMetadata(ctx context.Context, client *databricks.WorkspaceClient, version, sessionID string, metadata *WorkspaceMetadata) error { + metadataBytes, err := json.Marshal(metadata) if err != nil { return fmt.Errorf("failed to marshal metadata: %w", err) } - contentDir, err := GetWorkspaceContentDir(ctx, client, version, clusterID) + contentDir, err := GetWorkspaceContentDir(ctx, client, version, sessionID) if err != nil { return fmt.Errorf("failed to get workspace content directory: %w", err) }