Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 3 additions & 17 deletions experimental/ssh/cmd/connect.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package ssh

import (
"errors"
"time"

"github.com/databricks/cli/cmd/root"
Expand Down Expand Up @@ -82,22 +81,6 @@ the SSH server and handling the connection proxy.
cmd.RunE = func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
wsClient := cmdctx.WorkspaceClient(ctx)

if !proxyMode && clusterID == "" && connectionName == "" {
return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)")
}

if accelerator != "" && connectionName == "" {
return errors.New("--accelerator flag can only be used with serverless compute (--name flag)")
}

// Remove when we add support for serverless CPU
if connectionName != "" && accelerator == "" {
return errors.New("--name flag requires --accelerator to be set (for now we only support serverless GPU compute)")
}

// TODO: validate connectionName if provided

opts := client.ClientOptions{
Profile: wsClient.Config.Profile,
ClusterID: clusterID,
Expand All @@ -120,6 +103,9 @@ the SSH server and handling the connection proxy.
SkipSettingsCheck: skipSettingsCheck,
AdditionalArgs: args,
}
if err := opts.Validate(); err != nil {
return err
}
return client.Run(ctx, wsClient, opts)
}

Expand Down
27 changes: 23 additions & 4 deletions experimental/ssh/internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"os/exec"
"os/signal"
"path/filepath"
"regexp"
"strconv"
"strings"
"syscall"
Expand All @@ -38,6 +39,8 @@ var sshServerBootstrapScript string

var errServerMetadata = errors.New("server metadata error")

var connectionNameRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]*$`)

const (
sshServerTaskKey = "start_ssh_server"
serverlessEnvironmentKey = "ssh_tunnel_serverless"
Expand Down Expand Up @@ -97,6 +100,26 @@ type ClientOptions struct {
SkipSettingsCheck bool
}

func (o *ClientOptions) Validate() error {
if !o.ProxyMode && o.ClusterID == "" && o.ConnectionName == "" {
return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)")
}
if o.Accelerator != "" && o.ConnectionName == "" {
return errors.New("--accelerator flag can only be used with serverless compute (--name flag)")
}
// TODO: Remove when we add support for serverless CPU
if o.ConnectionName != "" && o.Accelerator == "" {
return errors.New("--name flag requires --accelerator to be set (for now we only support serverless GPU compute)")
}
if o.ConnectionName != "" && !connectionNameRegex.MatchString(o.ConnectionName) {
return fmt.Errorf("connection name %q must consist of letters, numbers, dashes, and underscores", o.ConnectionName)
}
if o.IDE != "" && o.IDE != VSCodeOption && o.IDE != CursorOption {
return fmt.Errorf("invalid IDE value: %q, expected %q or %q", o.IDE, VSCodeOption, CursorOption)
}
return nil
}

func (o *ClientOptions) IsServerlessMode() bool {
return o.ClusterID == "" && o.ConnectionName != ""
}
Expand Down Expand Up @@ -287,10 +310,6 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
}

func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error {
if opts.IDE != VSCodeOption && opts.IDE != CursorOption {
return fmt.Errorf("invalid IDE value: %s, expected '%s' or '%s'", opts.IDE, VSCodeOption, CursorOption)
}

connectionName := opts.SessionIdentifier()
if connectionName == "" {
return errors.New("connection name is required for IDE integration")
Expand Down
152 changes: 152 additions & 0 deletions experimental/ssh/internal/client/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
package client_test

import (
"fmt"
"os"
"testing"
"time"

"github.com/databricks/cli/experimental/ssh/internal/client"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestValidate(t *testing.T) {
tests := []struct {
name string
opts client.ClientOptions
wantErr string
}{
{
name: "no cluster or connection name",
opts: client.ClientOptions{},
wantErr: "please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)",
},
{
name: "proxy mode skips cluster/name check",
opts: client.ClientOptions{ProxyMode: true},
},
{
name: "cluster ID only",
opts: client.ClientOptions{ClusterID: "abc-123"},
},
{
name: "accelerator without connection name",
opts: client.ClientOptions{ClusterID: "abc-123", Accelerator: "GPU_1xA10"},
wantErr: "--accelerator flag can only be used with serverless compute (--name flag)",
},
{
name: "connection name without accelerator",
opts: client.ClientOptions{ConnectionName: "my-conn"},
wantErr: "--name flag requires --accelerator to be set (for now we only support serverless GPU compute)",
},
{
name: "invalid connection name characters",
opts: client.ClientOptions{ConnectionName: "my conn!", Accelerator: "GPU_1xA10"},
wantErr: `connection name "my conn!" must consist of letters, numbers, dashes, and underscores`,
},
{
name: "connection name with leading dash",
opts: client.ClientOptions{ConnectionName: "-my-conn", Accelerator: "GPU_1xA10"},
wantErr: `connection name "-my-conn" must consist of letters, numbers, dashes, and underscores`,
},
{
name: "valid connection name with accelerator",
opts: client.ClientOptions{ConnectionName: "my-conn_1", Accelerator: "GPU_1xA10"},
},
{
name: "both cluster ID and connection name",
opts: client.ClientOptions{ClusterID: "abc-123", ConnectionName: "my-conn", Accelerator: "GPU_1xA10"},
},
{
name: "proxy mode with invalid connection name",
opts: client.ClientOptions{ProxyMode: true, ConnectionName: "bad name!", Accelerator: "GPU_1xA10"},
wantErr: `connection name "bad name!" must consist of letters, numbers, dashes, and underscores`,
},
{
name: "invalid IDE value",
opts: client.ClientOptions{ClusterID: "abc-123", IDE: "vim"},
wantErr: `invalid IDE value: "vim", expected "vscode" or "cursor"`,
},
{
name: "valid IDE vscode",
opts: client.ClientOptions{ClusterID: "abc-123", IDE: "vscode"},
},
{
name: "valid IDE cursor",
opts: client.ClientOptions{ClusterID: "abc-123", IDE: "cursor"},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.opts.Validate()
if tt.wantErr == "" {
assert.NoError(t, err)
} else {
assert.EqualError(t, err, tt.wantErr)
}
})
}
}

func TestToProxyCommand(t *testing.T) {
exe, err := os.Executable()
require.NoError(t, err)
quoted := fmt.Sprintf("%q", exe)

tests := []struct {
name string
opts client.ClientOptions
want string
}{
{
name: "dedicated cluster",
opts: client.ClientOptions{ClusterID: "abc-123", ShutdownDelay: 5 * time.Minute},
want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=false --shutdown-delay=5m0s",
},
{
name: "dedicated cluster with auto-start",
opts: client.ClientOptions{ClusterID: "abc-123", AutoStartCluster: true, ShutdownDelay: 5 * time.Minute},
want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=true --shutdown-delay=5m0s",
},
{
name: "serverless",
opts: client.ClientOptions{ConnectionName: "my-conn", ShutdownDelay: 2 * time.Minute},
want: quoted + " ssh connect --proxy --name=my-conn --shutdown-delay=2m0s",
},
{
name: "serverless with accelerator",
opts: client.ClientOptions{ConnectionName: "my-conn", Accelerator: "GPU_1xA10", ShutdownDelay: 2 * time.Minute},
want: quoted + " ssh connect --proxy --name=my-conn --shutdown-delay=2m0s --accelerator=GPU_1xA10",
},
{
name: "with metadata",
opts: client.ClientOptions{ClusterID: "abc-123", ServerMetadata: "user,2222,abc-123"},
want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=false --shutdown-delay=0s --metadata=user,2222,abc-123",
},
{
name: "with handover timeout",
opts: client.ClientOptions{ClusterID: "abc-123", HandoverTimeout: 10 * time.Minute},
want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=false --shutdown-delay=0s --handover-timeout=10m0s",
},
{
name: "with profile",
opts: client.ClientOptions{ClusterID: "abc-123", Profile: "my-profile"},
want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=false --shutdown-delay=0s --profile=my-profile",
},
{
name: "with liteswap",
opts: client.ClientOptions{ClusterID: "abc-123", Liteswap: "test-env"},
want: quoted + " ssh connect --proxy --cluster=abc-123 --auto-start-cluster=false --shutdown-delay=0s --liteswap=test-env",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.opts.ToProxyCommand()
require.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}