Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions experimental/ssh/cmd/connect.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ssh

import (
"errors"
"time"

"github.com/databricks/cli/cmd/root"
Expand All @@ -22,21 +23,31 @@ the SSH server and handling the connection proxy.
}

var clusterID string
var connectionName string
var accelerator string
var proxyMode bool
var ide string
var serverMetadata string
var shutdownDelay time.Duration
var maxClients int
var handoverTimeout time.Duration
var releasesDir string
var autoStartCluster bool
var userKnownHostsFile string
var liteswap string

cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (required)")
cmd.MarkFlagRequired("cluster")
cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)")
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects")
cmd.Flags().IntVar(&maxClients, "max-clients", defaultMaxClients, "Maximum number of SSH clients")
cmd.Flags().BoolVar(&autoStartCluster, "auto-start-cluster", true, "Automatically start the cluster if it is not running")

cmd.Flags().StringVar(&connectionName, "name", "", "Connection name (for serverless compute)")
cmd.Flags().MarkHidden("name")
cmd.Flags().StringVar(&accelerator, "accelerator", "", "GPU accelerator type (GPU_1xA10 or GPU_8xH100)")
cmd.Flags().MarkHidden("accelerator")
cmd.Flags().StringVar(&ide, "ide", "", "Open remote IDE window (vscode or cursor)")
cmd.Flags().MarkHidden("ide")

cmd.Flags().BoolVar(&proxyMode, "proxy", false, "ProxyCommand mode")
cmd.Flags().MarkHidden("proxy")
cmd.Flags().StringVar(&serverMetadata, "metadata", "", "Metadata of the running SSH server (format: <user_name>,<port>)")
Expand All @@ -50,6 +61,9 @@ the SSH server and handling the connection proxy.
cmd.Flags().StringVar(&userKnownHostsFile, "user-known-hosts-file", "", "Path to user known hosts file for SSH client")
cmd.Flags().MarkHidden("user-known-hosts-file")

cmd.Flags().StringVar(&liteswap, "liteswap", "", "Liteswap header value for traffic routing (dev/test only)")
cmd.Flags().MarkHidden("liteswap")

cmd.PreRunE = func(cmd *cobra.Command, args []string) error {
// CLI in the proxy mode is executed by the ssh client and can't prompt for input
if proxyMode {
Expand All @@ -64,20 +78,41 @@ the SSH server and handling the connection proxy.
cmd.RunE = func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
wsClient := cmdctx.WorkspaceClient(ctx)

if !proxyMode && clusterID == "" && connectionName == "" {
return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)")
}

if accelerator != "" && connectionName == "" {
return errors.New("--accelerator flag can only be used with serverless compute (--name flag)")
}

// Remove when we add support for serverless CPU
if connectionName != "" && accelerator == "" {
return errors.New("--name flag requires --accelerator to be set (for now we only support serverless GPU compute)")
}

// TODO: validate connectionName if provided

opts := client.ClientOptions{
Profile: wsClient.Config.Profile,
ClusterID: clusterID,
ConnectionName: connectionName,
Accelerator: accelerator,
ProxyMode: proxyMode,
IDE: ide,
ServerMetadata: serverMetadata,
ShutdownDelay: shutdownDelay,
MaxClients: maxClients,
HandoverTimeout: handoverTimeout,
ReleasesDir: releasesDir,
ServerTimeout: max(serverTimeout, shutdownDelay),
TaskStartupTimeout: taskStartupTimeout,
AutoStartCluster: autoStartCluster,
ClientPublicKeyName: clientPublicKeyName,
ClientPrivateKeyName: clientPrivateKeyName,
UserKnownHostsFile: userKnownHostsFile,
Liteswap: liteswap,
AdditionalArgs: args,
}
return client.Run(ctx, wsClient, opts)
Expand Down
1 change: 1 addition & 0 deletions experimental/ssh/cmd/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ const (
defaultHandoverTimeout = 30 * time.Minute

serverTimeout = 24 * time.Hour
taskStartupTimeout = 10 * time.Minute
serverPortRange = 100
serverConfigDir = ".ssh-tunnel"
serverPrivateKeyName = "server-private-key"
Expand Down
4 changes: 4 additions & 0 deletions experimental/ssh/cmd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,15 @@ and proxies them to local SSH daemon processes.
var maxClients int
var shutdownDelay time.Duration
var clusterID string
var sessionID string
var version string
var secretScopeName string
var authorizedKeySecretName string

cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID")
cmd.MarkFlagRequired("cluster")
cmd.Flags().StringVar(&sessionID, "session-id", "", "Session identifier (cluster ID or serverless connection name)")
cmd.MarkFlagRequired("session-id")
cmd.Flags().StringVar(&secretScopeName, "secret-scope-name", "", "Databricks secret scope name to store SSH keys")
cmd.MarkFlagRequired("secret-scope-name")
cmd.Flags().StringVar(&authorizedKeySecretName, "authorized-key-secret-name", "", "Name of the secret containing the client public key")
Expand All @@ -56,6 +59,7 @@ and proxies them to local SSH daemon processes.
wsc := cmdctx.WorkspaceClient(ctx)
opts := server.ServerOptions{
ClusterID: clusterID,
SessionID: sessionID,
MaxClients: maxClients,
ShutdownDelay: shutdownDelay,
Version: version,
Expand Down
21 changes: 17 additions & 4 deletions experimental/ssh/cmd/setup.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package ssh

import (
"fmt"
"time"

"github.com/databricks/cli/cmd/root"
"github.com/databricks/cli/experimental/ssh/internal/client"
"github.com/databricks/cli/experimental/ssh/internal/setup"
"github.com/databricks/cli/libs/cmdctx"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -43,16 +45,27 @@ an SSH host configuration to your SSH config file.

cmd.RunE = func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
client := cmdctx.WorkspaceClient(ctx)
opts := setup.SetupOptions{
wsClient := cmdctx.WorkspaceClient(ctx)
setupOpts := setup.SetupOptions{
HostName: hostName,
ClusterID: clusterID,
AutoStartCluster: autoStartCluster,
SSHConfigPath: sshConfigPath,
ShutdownDelay: shutdownDelay,
Profile: client.Config.Profile,
Profile: wsClient.Config.Profile,
}
return setup.Setup(ctx, client, opts)
clientOpts := client.ClientOptions{
ClusterID: setupOpts.ClusterID,
AutoStartCluster: setupOpts.AutoStartCluster,
ShutdownDelay: setupOpts.ShutdownDelay,
Profile: setupOpts.Profile,
}
proxyCommand, err := clientOpts.ToProxyCommand()
if err != nil {
return fmt.Errorf("failed to generate ProxyCommand: %w", err)
}
setupOpts.ProxyCommand = proxyCommand
return setup.Setup(ctx, wsClient, setupOpts)
}

return cmd
Expand Down
Loading