diff --git a/experimental/ssh/cmd/connect.go b/experimental/ssh/cmd/connect.go index c05430368a..4b1de61f3c 100644 --- a/experimental/ssh/cmd/connect.go +++ b/experimental/ssh/cmd/connect.go @@ -1,7 +1,6 @@ package ssh import ( - "errors" "time" "github.com/databricks/cli/cmd/root" @@ -82,22 +81,6 @@ the SSH server and handling the connection proxy. cmd.RunE = func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() wsClient := cmdctx.WorkspaceClient(ctx) - - if !proxyMode && clusterID == "" && connectionName == "" { - return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)") - } - - if accelerator != "" && connectionName == "" { - return errors.New("--accelerator flag can only be used with serverless compute (--name flag)") - } - - // Remove when we add support for serverless CPU - if connectionName != "" && accelerator == "" { - return errors.New("--name flag requires --accelerator to be set (for now we only support serverless GPU compute)") - } - - // TODO: validate connectionName if provided - opts := client.ClientOptions{ Profile: wsClient.Config.Profile, ClusterID: clusterID, @@ -120,6 +103,9 @@ the SSH server and handling the connection proxy. SkipSettingsCheck: skipSettingsCheck, AdditionalArgs: args, } + if err := opts.Validate(); err != nil { + return err + } return client.Run(ctx, wsClient, opts) } diff --git a/experimental/ssh/internal/client/client.go b/experimental/ssh/internal/client/client.go index 4998fe4dea..d46f602955 100644 --- a/experimental/ssh/internal/client/client.go +++ b/experimental/ssh/internal/client/client.go @@ -12,6 +12,7 @@ import ( "os/exec" "os/signal" "path/filepath" + "regexp" "strconv" "strings" "syscall" @@ -38,6 +39,8 @@ var sshServerBootstrapScript string var errServerMetadata = errors.New("server metadata error") +var connectionNameRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]*$`) + const ( sshServerTaskKey = "start_ssh_server" serverlessEnvironmentKey = "ssh_tunnel_serverless" @@ -97,6 +100,26 @@ type ClientOptions struct { SkipSettingsCheck bool } +func (o *ClientOptions) Validate() error { + if !o.ProxyMode && o.ClusterID == "" && o.ConnectionName == "" { + return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)") + } + if o.Accelerator != "" && o.ConnectionName == "" { + return errors.New("--accelerator flag can only be used with serverless compute (--name flag)") + } + // TODO: Remove when we add support for serverless CPU + if o.ConnectionName != "" && o.Accelerator == "" { + return errors.New("--name flag requires --accelerator to be set (for now we only support serverless GPU compute)") + } + if o.ConnectionName != "" && !connectionNameRegex.MatchString(o.ConnectionName) { + return fmt.Errorf("connection name %q must consist of letters, numbers, dashes, and underscores", o.ConnectionName) + } + if o.IDE != "" && o.IDE != VSCodeOption && o.IDE != CursorOption { + return fmt.Errorf("invalid IDE value: %q, expected %q or %q", o.IDE, VSCodeOption, CursorOption) + } + return nil +} + func (o *ClientOptions) IsServerlessMode() bool { return o.ClusterID == "" && o.ConnectionName != "" } @@ -287,10 +310,6 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt } func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error { - if opts.IDE != VSCodeOption && opts.IDE != CursorOption { - return fmt.Errorf("invalid IDE value: %s, expected '%s' or '%s'", opts.IDE, VSCodeOption, CursorOption) - } - connectionName := opts.SessionIdentifier() if connectionName == "" { return errors.New("connection name is required for IDE integration") diff --git a/experimental/ssh/internal/client/client_test.go b/experimental/ssh/internal/client/client_test.go new file mode 100644 index 0000000000..a727c99f00 --- /dev/null +++ b/experimental/ssh/internal/client/client_test.go @@ -0,0 +1,152 @@ +package client_test + +import ( + "fmt" + "os" + "testing" + "time" + + "github.com/databricks/cli/experimental/ssh/internal/client" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestValidate(t *testing.T) { + tests := []struct { + name string + opts client.ClientOptions + wantErr string + }{ + { + name: "no cluster or connection name", + opts: client.ClientOptions{}, + wantErr: "please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)", + }, + { + name: "proxy mode skips cluster/name check", + opts: client.ClientOptions{ProxyMode: true}, + }, + { + name: "cluster ID only", + opts: client.ClientOptions{ClusterID: "abc-123"}, + }, + { + name: "accelerator without connection name", + opts: client.ClientOptions{ClusterID: "abc-123", Accelerator: "GPU_1xA10"}, + wantErr: "--accelerator flag can only be used with serverless compute (--name flag)", + }, + { + name: "connection name without accelerator", + opts: client.ClientOptions{ConnectionName: "my-conn"}, + wantErr: "--name flag requires --accelerator to be set (for now we only support serverless GPU compute)", + }, + { + name: "invalid connection name characters", + opts: client.ClientOptions{ConnectionName: "my conn!", Accelerator: "GPU_1xA10"}, + wantErr: `connection name "my conn!" must consist of letters, numbers, dashes, and underscores`, + }, + { + name: "connection name with leading dash", + opts: client.ClientOptions{ConnectionName: "-my-conn", Accelerator: "GPU_1xA10"}, + wantErr: `connection name "-my-conn" must consist of letters, numbers, dashes, and underscores`, + }, + { + name: "valid connection name with accelerator", + opts: client.ClientOptions{ConnectionName: "my-conn_1", Accelerator: "GPU_1xA10"}, + }, + { + name: "both cluster ID and connection name", + opts: client.ClientOptions{ClusterID: "abc-123", ConnectionName: "my-conn", Accelerator: "GPU_1xA10"}, + }, + { + name: "proxy mode with invalid connection name", + opts: client.ClientOptions{ProxyMode: true, ConnectionName: "bad name!", Accelerator: "GPU_1xA10"}, + wantErr: `connection name "bad name!" must consist of letters, numbers, dashes, and underscores`, + }, + { + name: "invalid IDE value", + opts: client.ClientOptions{ClusterID: "abc-123", IDE: "vim"}, + wantErr: `invalid IDE value: "vim", expected "vscode" or "cursor"`, + }, + { + name: "valid IDE vscode", + opts: client.ClientOptions{ClusterID: "abc-123", IDE: "vscode"}, + }, + { + name: "valid IDE cursor", + opts: client.ClientOptions{ClusterID: "abc-123", IDE: "cursor"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.opts.Validate() + if tt.wantErr == "" { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, tt.wantErr) + } + }) + } +} + +func TestToProxyCommand(t *testing.T) { + exe, err := os.Executable() + require.NoError(t, err) + quoted := fmt.Sprintf("%q", exe) + + tests := []struct { + name string + opts client.ClientOptions + want string + }{ + { + name: "dedicated cluster", + opts: client.ClientOptions{ClusterID: "abc-123", ShutdownDelay: 5 * time.Minute}, + want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=false --shutdown-delay=5m0s", + }, + { + name: "dedicated cluster with auto-start", + opts: client.ClientOptions{ClusterID: "abc-123", AutoStartCluster: true, ShutdownDelay: 5 * time.Minute}, + want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=true --shutdown-delay=5m0s", + }, + { + name: "serverless", + opts: client.ClientOptions{ConnectionName: "my-conn", ShutdownDelay: 2 * time.Minute}, + want: quoted + " ssh connect --proxy --name=my-conn --shutdown-delay=2m0s", + }, + { + name: "serverless with accelerator", + opts: client.ClientOptions{ConnectionName: "my-conn", Accelerator: "GPU_1xA10", ShutdownDelay: 2 * time.Minute}, + want: quoted + " ssh connect --proxy --name=my-conn --shutdown-delay=2m0s --accelerator=GPU_1xA10", + }, + { + name: "with metadata", + opts: client.ClientOptions{ClusterID: "abc-123", ServerMetadata: "user,2222,abc-123"}, + want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=false --shutdown-delay=0s --metadata=user,2222,abc-123", + }, + { + name: "with handover timeout", + opts: client.ClientOptions{ClusterID: "abc-123", HandoverTimeout: 10 * time.Minute}, + want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=false --shutdown-delay=0s --handover-timeout=10m0s", + }, + { + name: "with profile", + opts: client.ClientOptions{ClusterID: "abc-123", Profile: "my-profile"}, + want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=false --shutdown-delay=0s --profile=my-profile", + }, + { + name: "with liteswap", + opts: client.ClientOptions{ClusterID: "abc-123", Liteswap: "test-env"}, + want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=false --shutdown-delay=0s --liteswap=test-env", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.opts.ToProxyCommand() + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +}