From 5c4633cad6ccd480521a40cd30bc875cde2a5021 Mon Sep 17 00:00:00 2001 From: Brad Swain Date: Mon, 23 Feb 2026 15:06:41 -0600 Subject: [PATCH] add setup-claude command --- README.md | 41 + dropkit/main.py | 697 +++++++++++++++- tests/test_setup_claude.py | 1594 ++++++++++++++++++++++++++++++++++++ 3 files changed, 2296 insertions(+), 36 deletions(-) create mode 100644 tests/test_setup_claude.py diff --git a/README.md b/README.md index 786bbe4..85edbbe 100644 --- a/README.md +++ b/README.md @@ -98,6 +98,7 @@ Commands: off Power off a droplet (requires confirmation). hibernate Hibernate a droplet (snapshot and destroy to save costs). wake Wake a hibernated droplet (restore from snapshot). + setup-claude Set up Claude Code on an existing droplet. enable-tailscale Enable Tailscale VPN on an existing droplet. list-ssh-keys List SSH keys registered via dropkit. add-ssh-key Add or import an SSH public key to DigitalOcean. @@ -180,6 +181,46 @@ dropkit destroy my-droplet **Note:** Snapshots are billed at $0.06/GB/month, which is typically much cheaper than keeping a droplet running. +### Claude Code Setup + +Set up [Claude Code](https://claude.ai/claude-code) on a droplet for sandboxed AI development: + +```bash +# Create a droplet, then set up Claude Code +dropkit create my-sandbox +dropkit setup-claude my-sandbox +``` + +After installing Claude Code, an interactive prompt lets you choose what to sync: + +``` +What to set up on the droplet: + 1. Global CLAUDE.md + 2. Settings (model prefs, UI, behavior) + 3. GitHub token + 4. Marketplace: claude-plugins-official + +Enter numbers to sync (comma-separated), or "all" [all]: +``` + +Use `--sync-all` to skip the prompt and sync everything (useful for scripting): + +```bash +dropkit setup-claude my-sandbox --sync-all +``` + +**What it does:** +1. Installs Claude Code via the official installer +2. Prompts you to select which settings to sync (CLAUDE.md, settings, GitHub token, marketplaces) +3. 
Syncs your selected items to the droplet +4. Runs `claude /login` on the droplet for you to authenticate your Claude subscription + +**GitHub CLI authentication:** +```bash +export GITHUB_TOKEN=ghp_... +dropkit setup-claude my-sandbox +``` + ### Cloud-Init Customization Edit `~/.config/dropkit/cloud-init.yaml` to customize user setup, package installation, firewall rules, and shell configuration. The template uses Jinja2 syntax with variables `{{ username }}` and `{{ ssh_keys }}`. diff --git a/dropkit/main.py b/dropkit/main.py index 26f0bb2..c39a1db 100644 --- a/dropkit/main.py +++ b/dropkit/main.py @@ -1,10 +1,15 @@ """Main CLI application for dropkit.""" +import contextlib +import dataclasses import json +import os import re +import shlex import shutil import subprocess import sys +import tempfile import time from pathlib import Path from typing import Any @@ -42,6 +47,13 @@ ) console = Console() +SSH_OPTS = [ + "-o", + "StrictHostKeyChecking=no", + "-o", + "UserKnownHostsFile=/dev/null", +] + @app.callback() def main_callback(): @@ -644,10 +656,7 @@ def wait_for_cloud_init(ssh_hostname: str, verbose: bool = False) -> tuple[bool, result = subprocess.run( [ "ssh", - "-o", - "StrictHostKeyChecking=no", - "-o", - "UserKnownHostsFile=/dev/null", + *SSH_OPTS, "-o", "ConnectTimeout=5", "-o", @@ -825,10 +834,7 @@ def run_tailscale_up(ssh_hostname: str, verbose: bool = False) -> str | None: result = subprocess.run( [ "ssh", - "-o", - "StrictHostKeyChecking=no", - "-o", - "UserKnownHostsFile=/dev/null", + *SSH_OPTS, "-o", "ConnectTimeout=10", ssh_hostname, @@ -885,10 +891,7 @@ def tailscale_logout(ssh_hostname: str, verbose: bool = False) -> bool: result = subprocess.run( [ "ssh", - "-o", - "StrictHostKeyChecking=no", - "-o", - "UserKnownHostsFile=/dev/null", + *SSH_OPTS, "-o", "ConnectTimeout=10", ssh_hostname, @@ -1116,10 +1119,7 @@ def get_tailscale_ip(ssh_hostname: str) -> str | None: result = subprocess.run( [ "ssh", - "-o", - "StrictHostKeyChecking=no", - "-o", - 
"UserKnownHostsFile=/dev/null", + *SSH_OPTS, "-o", "ConnectTimeout=5", "-o", @@ -1171,10 +1171,7 @@ def wait_for_tailscale_ip( result = subprocess.run( [ "ssh", - "-o", - "StrictHostKeyChecking=no", - "-o", - "UserKnownHostsFile=/dev/null", + *SSH_OPTS, "-o", "ConnectTimeout=5", "-o", @@ -1240,10 +1237,7 @@ def lock_down_to_tailscale(ssh_hostname: str, verbose: bool = False) -> bool: result = subprocess.run( [ "ssh", - "-o", - "StrictHostKeyChecking=no", - "-o", - "UserKnownHostsFile=/dev/null", + *SSH_OPTS, "-o", "ConnectTimeout=10", ssh_hostname, @@ -1295,10 +1289,7 @@ def verify_tailscale_ssh( result = subprocess.run( [ "ssh", - "-o", - "StrictHostKeyChecking=no", - "-o", - "UserKnownHostsFile=/dev/null", + *SSH_OPTS, "-o", "ConnectTimeout=10", "-o", @@ -1338,10 +1329,7 @@ def check_tailscale_installed(ssh_hostname: str, verbose: bool = False) -> bool: result = subprocess.run( [ "ssh", - "-o", - "StrictHostKeyChecking=no", - "-o", - "UserKnownHostsFile=/dev/null", + *SSH_OPTS, "-o", "ConnectTimeout=10", "-o", @@ -1382,10 +1370,7 @@ def install_tailscale_on_droplet(ssh_hostname: str, verbose: bool = False) -> bo result = subprocess.run( [ "ssh", - "-o", - "StrictHostKeyChecking=no", - "-o", - "UserKnownHostsFile=/dev/null", + *SSH_OPTS, "-o", "ConnectTimeout=10", ssh_hostname, @@ -4443,6 +4428,646 @@ def delete_ssh_key_cmd( raise typer.Exit(1) +# --- setup-claude helpers --- + + +@dataclasses.dataclass(frozen=True) +class SyncChoice: + """One selectable item in the setup-claude sync prompt.""" + + label: str + key: str # "claude_md", "settings", "github_token", or "marketplace:" + + +_GITHUB_TOKEN_PREFIXES = ("ghp_", "gho_", "github_pat_", "ghs_") + + +def _discover_sync_choices() -> list[SyncChoice]: + """Build the list of sync items available on this machine. + + Always includes CLAUDE.md and Settings when the local source exists. + Adds GitHub token when ``GITHUB_TOKEN`` is set with a valid prefix. 
+ Discovers marketplaces from ``~/.claude/plugins/known_marketplaces.json``. + """ + choices: list[SyncChoice] = [] + + if Path("~/.claude/CLAUDE.md").expanduser().exists(): + choices.append(SyncChoice("Global CLAUDE.md", "claude_md")) + if Path("~/.claude/settings.json").expanduser().exists(): + choices.append(SyncChoice("Settings (model prefs, UI, behavior)", "settings")) + + pat = os.environ.get("GITHUB_TOKEN", "") + if pat and pat.startswith(_GITHUB_TOKEN_PREFIXES): + choices.append(SyncChoice("GitHub token", "github_token")) + + mp_path = Path("~/.claude/plugins/known_marketplaces.json").expanduser() + if mp_path.is_file(): + try: + mp_data = json.loads(mp_path.read_text(encoding="utf-8")) + choices.extend( + SyncChoice(f"Marketplace: {name}", f"marketplace:{name}") + for name in sorted(mp_data) + ) + except (OSError, json.JSONDecodeError) as e: + console.print( + f"[yellow]Warning: Could not load marketplaces from {mp_path}: {e}[/yellow]" + ) + + return choices + + +def _prompt_sync_selection(choices: list[SyncChoice]) -> set[str]: + """Show an interactive numbered list and return the selected keys. + + Accepts comma-separated numbers, ``"all"`` (default), or ``"none"``. + Returns an empty set when the user types ``"none"`` or the list is empty. + """ + if not choices: + return set() + + console.print("\n[bold]What to set up on the droplet:[/bold]") + for i, choice in enumerate(choices, 1): + console.print(f" {i}. {choice.label}") + + answer = Prompt.ask( + '\nEnter numbers to sync (comma-separated), or "all"', + default="all", + ) + answer = answer.strip().lower() + + if answer == "all": + return {c.key for c in choices} + if answer == "none": + return set() + + selected: set[str] = set() + for raw_part in answer.split(","): + part = raw_part.strip() + if not part: + continue + try: + idx = int(part) + except ValueError: + console.print( + f'[red]Invalid input: "{part}". 
Expected a number, "all", or "none".[/red]' + ) + raise typer.Exit(1) + if 1 <= idx <= len(choices): + selected.add(choices[idx - 1].key) + else: + console.print( + f"[red]Invalid choice: {idx}. Must be between 1 and {len(choices)}.[/red]" + ) + raise typer.Exit(1) + + return selected + + +# Keys safe to sync from settings.json — anything NOT on this list is stripped. +# Excludes: permissions (local paths), env (secrets), hooks/apiKeyHelper/ +# awsAuthRefresh/awsCredentialExport/otelHeadersHelper/fileSuggestion (local +# scripts), statusLine (may contain command), sandbox (internal infra), etc. +_SETTINGS_SAFE_KEYS: set[str] = { + # Model preferences + "model", + "effortLevel", + "fastMode", + "availableModels", + "alwaysThinkingEnabled", + # UI / terminal + "language", + "prefersReducedMotion", + "showTurnDuration", + "spinnerTipsEnabled", + "spinnerTipsOverride", + "spinnerVerbs", + "terminalProgressBarEnabled", + "outputStyle", + "attorneyMode", + "teammateMode", + # Plugins + "enabledPlugins", + "extraKnownMarketplaces", + "strictKnownMarketplaces", + "skippedMarketplaces", + "skippedPlugins", + "pluginConfigs", + # MCP approval toggles (not server configs — those live in .mcp.json) + "enableAllProjectMcpServers", + "enabledMcpjsonServers", + "disabledMcpjsonServers", + # Behavioral + "respectGitignore", + "autoUpdatesChannel", + "disableAllHooks", + "cleanupPeriodDays", + "skipWebFetchPreflight", + # Git + "attribution", + "includeCoAuthoredBy", + # Admin / org + "forceLoginMethod", + "forceLoginOrgUUID", + "companyAnnouncements", +} + + +def _sanitize_settings(data: dict[str, Any]) -> dict[str, Any]: + """Return a copy of settings keeping only known-safe keys.""" + result = {k: v for k, v in data.items() if k in _SETTINGS_SAFE_KEYS} + if "extraKnownMarketplaces" in result and isinstance(result["extraKnownMarketplaces"], dict): + result["extraKnownMarketplaces"] = { + name: source + for name, source in result["extraKnownMarketplaces"].items() + if 
isinstance(source, dict) and source.get("type") not in ("file", "directory") + } + return result + + +def _sanitize_known_marketplaces( + data: dict[str, Any], + local_home: str, + remote_home: str, + marketplace_filter: set[str] | None = None, +) -> dict[str, Any]: + """Rewrite installLocation paths from local home to remote home. + + When *marketplace_filter* is set, only marketplace keys in the filter are + included in the output. + """ + result: dict[str, Any] = {} + for key, entry in data.items(): + if marketplace_filter is not None and key not in marketplace_filter: + continue + copied = dict(entry) + if "installLocation" in copied: + copied["installLocation"] = copied["installLocation"].replace( + local_home, remote_home, 1 + ) + result[key] = copied + return result + + +def _sanitize_installed_plugins( + data: dict[str, Any], + local_home: str, + remote_home: str, + marketplace_filter: set[str] | None = None, +) -> dict[str, Any]: + """Rewrite installPath for user-scope entries and strip local-scope entries. + + Supports the v2 format: ``{"version": 2, "plugins": {name: [entries...]}}``. + Each plugin value is a list of entry dicts with scope/installPath fields. + + When *marketplace_filter* is set, only plugins whose ``@source`` suffix + matches a name in the filter are included. 
+ """ + plugins = data.get("plugins", {}) + result: dict[str, list[dict[str, Any]]] = {} + for key, entries in plugins.items(): + if marketplace_filter is not None: + # Plugin keys use format "plugin-name@marketplace-name" + source = key.rsplit("@", 1)[-1] if "@" in key else "" + if source not in marketplace_filter: + continue + kept: list[dict[str, Any]] = [] + for entry in entries: + if entry.get("scope") == "local": + continue + copied = dict(entry) + if "installPath" in copied: + copied["installPath"] = copied["installPath"].replace(local_home, remote_home, 1) + kept.append(copied) + if kept: + result[key] = kept + return {"version": data.get("version", 2), "plugins": result} + + +# https://code.claude.com/docs/en/setup +INSTALL_CLAUDE_CODE = "curl -fsSL https://claude.ai/install.sh | bash" + + +def _ssh_cmd(ssh_hostname: str, remote_cmd: str) -> list[str]: + """Build an SSH command with standard options.""" + return ["ssh", *SSH_OPTS, "-o", "ConnectTimeout=10", ssh_hostname, remote_cmd] + + +def _decode_output(data: bytes) -> str: + """Decode subprocess output, replacing non-UTF-8 bytes.""" + return data.decode("utf-8", errors="replace").strip() + + +def _last_line(text: str, fallback: str = "unknown error") -> str: + """Return the last line of text, or a fallback if empty.""" + return text.splitlines()[-1] if text else fallback + + +def _install_claude_code(ssh_hostname: str, verbose: bool) -> bool: + """Install Claude Code via the native installer. Returns True on success.""" + # Use login shell so PATH includes ~/.local/bin where Claude installs + version_cmd = "bash -lc 'claude --version'" + try: + result = subprocess.run( + _ssh_cmd(ssh_hostname, version_cmd), + capture_output=True, + timeout=30, + ) + except (subprocess.TimeoutExpired, subprocess.SubprocessError) as e: + console.print(f"[red]FAILED[/red] ({e})") + return False + # Exit code 255 means SSH itself failed (connection refused, auth error, etc.) 
+ if result.returncode == 255: + reason = _last_line(_decode_output(result.stderr), "SSH connection failed") + console.print(f"[red]FAILED[/red] ({reason})") + return False + if result.returncode == 0: + version_str = _decode_output(result.stdout) + console.print(f"[green]done[/green] (already installed: {version_str})") + return True + + if result.returncode == 127: + if verbose: + console.print("[dim]claude not found, installing...[/dim]") + else: + stderr = _decode_output(result.stderr) + console.print( + f"[dim]claude exited with code {result.returncode}" + f" (reinstalling...){': ' + _last_line(stderr) if stderr else ''}[/dim]" + ) + try: + result = subprocess.run( + _ssh_cmd(ssh_hostname, INSTALL_CLAUDE_CODE), + capture_output=True, + timeout=300, + ) + except (subprocess.TimeoutExpired, subprocess.SubprocessError) as e: + console.print(f"[red]FAILED[/red] ({e})") + return False + if result.returncode != 0: + output = _decode_output(result.stderr) or _decode_output(result.stdout) + console.print("[red]FAILED[/red]") + if output: + console.print(f"[dim] {output if verbose else _last_line(output)}[/dim]") + else: + console.print("[dim] Install failed with no output. 
Re-run with --verbose.[/dim]") + return False + + # Get installed version (best-effort, install already succeeded) + version_str = "" + try: + ver_result = subprocess.run( + _ssh_cmd(ssh_hostname, version_cmd), + capture_output=True, + timeout=30, + ) + if ver_result.returncode == 0: + version_str = _decode_output(ver_result.stdout) + except (subprocess.TimeoutExpired, subprocess.SubprocessError): + pass # Best-effort: install already succeeded, version display is cosmetic + suffix = f" ({version_str})" if version_str else "" + console.print(f"[green]done[/green]{suffix}") + return True + + +def _auth_github(ssh_hostname: str, verbose: bool) -> None: + """Authenticate gh CLI on the droplet using GITHUB_TOKEN.""" + pat = os.environ.get("GITHUB_TOKEN", "") + if not pat: + console.print("[yellow]skipped[/yellow] (GITHUB_TOKEN not set)") + return + if not pat.startswith(_GITHUB_TOKEN_PREFIXES): + console.print("[yellow]skipped[/yellow] (GITHUB_TOKEN does not look like a GitHub token)") + return + + try: + result = subprocess.run( + _ssh_cmd(ssh_hostname, "gh auth login --with-token"), + input=pat.encode(), + capture_output=True, + timeout=60, + ) + except (subprocess.TimeoutExpired, subprocess.SubprocessError) as e: + console.print(f"[yellow]FAILED[/yellow] ({e})") + return + if result.returncode == 255: + stderr = _decode_output(result.stderr) + console.print(f"[yellow]FAILED[/yellow] (SSH connection lost: {_last_line(stderr)})") + return + if result.returncode != 0: + stderr = _decode_output(result.stderr) + if "command not found" in stderr: + console.print("[yellow]FAILED[/yellow] (gh CLI not installed on droplet)") + else: + console.print(f"[yellow]FAILED[/yellow] ({_last_line(stderr)})") + if verbose and stderr: + console.print(f"[dim] {stderr}[/dim]") + return + + console.print("[green]done[/green]") + + +def _sync_settings( + ssh_hostname: str, + verbose: bool, + remote_home: str, + selected: set[str] | None = None, +) -> None: + """Rsync selected Claude 
settings to the droplet. + + *selected* controls which items to sync. When ``None``, everything is + synced (backwards-compatible / ``--sync-all`` behaviour). Otherwise only + the items whose keys appear in *selected* are synced. + """ + # Build sync path list dynamically based on selection. + # Each entry: (display_label, local_path_str, remote_path, excludes) + sync_paths: list[tuple[str, str, str, list[str]]] = [] + if selected is None or "claude_md" in selected: + sync_paths.append(("CLAUDE.md", "~/.claude/CLAUDE.md", ".claude/CLAUDE.md", [])) + if selected is None or "settings" in selected: + sync_paths.append(("settings.json", "~/.claude/settings.json", ".claude/settings.json", [])) + + # Marketplace items: sync filtered JSON manifests + per-marketplace dirs. + marketplace_names: set[str] = set() + if selected is not None: + marketplace_names = { + k.removeprefix("marketplace:") for k in selected if k.startswith("marketplace:") + } + has_marketplace = selected is None or bool(marketplace_names) + if has_marketplace: + mp_filter = marketplace_names or None # None means "all" + sync_paths.append( + ( + "known_marketplaces.json", + "~/.claude/plugins/known_marketplaces.json", + ".claude/plugins/known_marketplaces.json", + [], + ) + ) + sync_paths.append( + ( + "installed_plugins.json", + "~/.claude/plugins/installed_plugins.json", + ".claude/plugins/installed_plugins.json", + [], + ) + ) + # Per-marketplace dirs (cache + marketplaces). When filtering, only + # sync the selected marketplace subdirectories. 
+ if mp_filter: + for name in sorted(mp_filter): + sync_paths.append( + ( + f"marketplace:{name}", + f"~/.claude/plugins/marketplaces/{name}/", + f".claude/plugins/marketplaces/{name}/", + [".git"], + ) + ) + sync_paths.append( + ( + f"cache:{name}", + f"~/.claude/plugins/cache/{name}/", + f".claude/plugins/cache/{name}/", + [".git"], + ) + ) + else: + sync_paths.append( + ( + "plugins/", + "~/.claude/plugins/", + ".claude/plugins/", + ["known_marketplaces.json", "installed_plugins.json", ".git"], + ) + ) + else: + mp_filter = None + + # Resolve which local paths actually exist before doing any remote work. + to_sync: list[tuple[str, str, Path, list[str]]] = [] + for _label, local_path, remote_path, excludes in sync_paths: + expanded = Path(local_path).expanduser() + if expanded.exists(): + to_sync.append((local_path, remote_path, expanded, excludes)) + elif verbose: + console.print(f"[dim] skipping {local_path} (not found locally)[/dim]") + + if not to_sync: + console.print("[yellow]skipped[/yellow] (no local settings found)") + return + + # Ensure remote directories exist. + dirs_to_create = ".claude .claude/plugins" + if has_marketplace: + dirs_to_create += " .claude/plugins/marketplaces .claude/plugins/cache" + try: + mkdir_result = subprocess.run( + _ssh_cmd(ssh_hostname, f"mkdir -p {dirs_to_create}"), + capture_output=True, + timeout=30, + ) + except (subprocess.TimeoutExpired, subprocess.SubprocessError) as e: + console.print(f"[yellow]FAILED[/yellow] (mkdir: {e})") + return + if mkdir_result.returncode != 0: + stderr = _decode_output(mkdir_result.stderr) + console.print(f"[yellow]FAILED[/yellow] (mkdir: {_last_line(stderr)})") + return + + rsync_ssh = "ssh " + " ".join(shlex.quote(o) for o in SSH_OPTS) + local_home = str(Path.home()) + synced = 0 + failures: list[str] = [] + for local_path, remote_path, expanded, excludes in to_sync: + # Sanitize JSON files: strip sensitive keys / rewrite paths before syncing. 
+ tmp_path: Path | None = None + rsync_source = str(expanded) + ("/" if expanded.is_dir() else "") + if expanded.name in ("settings.json", "known_marketplaces.json", "installed_plugins.json"): + try: + data = json.loads(expanded.read_text(encoding="utf-8")) + if expanded.name == "settings.json": + sanitized = _sanitize_settings(data) + elif expanded.name == "known_marketplaces.json": + sanitized = _sanitize_known_marketplaces( + data, local_home, remote_home, marketplace_filter=mp_filter + ) + elif expanded.name == "installed_plugins.json": + sanitized = _sanitize_installed_plugins( + data, local_home, remote_home, marketplace_filter=mp_filter + ) + prefix = f"dropkit-{expanded.stem}-" + tmp_fd, tmp_name = tempfile.mkstemp(suffix=".json", prefix=prefix) + os.close(tmp_fd) + tmp_path = Path(tmp_name) + tmp_path.write_text(json.dumps(sanitized, indent=2), encoding="utf-8") + rsync_source = str(tmp_path) + except (OSError, json.JSONDecodeError) as e: + if verbose: + console.print(f"[dim] skipping {local_path} (sanitize failed: {e})[/dim]") + failures.append(f"{local_path} (sanitize failed: {e})") + if tmp_path: + with contextlib.suppress(OSError): + tmp_path.unlink() + continue + + try: + rsync_cmd = [ + "rsync", + "-az", + "--no-perms", + *[f"--exclude={e}" for e in excludes], + "-e", + rsync_ssh, + rsync_source, + f"{ssh_hostname}:{remote_path}", + ] + + if verbose: + console.print(f"[dim] {' '.join(rsync_cmd)}[/dim]") + + try: + result = subprocess.run(rsync_cmd, capture_output=True, timeout=300) + except (subprocess.TimeoutExpired, subprocess.SubprocessError, OSError) as e: + failures.append(f"{local_path} ({e})") + continue + if result.returncode != 0: + stderr = _decode_output(result.stderr) + failures.append(f"{local_path} ({_last_line(stderr)})") + if verbose and stderr: + console.print(f"[dim] {local_path}: {stderr}[/dim]") + continue + synced += 1 + finally: + if tmp_path: + with contextlib.suppress(OSError): + tmp_path.unlink() + + if failures: + 
console.print(f"[yellow]FAILED[/yellow] ({', '.join(failures)})") + else: + console.print(f"[green]done[/green] ({synced} path(s))") + + +def _open_auth_session(ssh_hostname: str, verbose: bool) -> bool: + """Run ``claude /login`` on the droplet for Claude Code authentication. + + Returns True if the session exited cleanly, False otherwise. + """ + console.print( + "\n[bold]Authenticate Claude Code[/bold]" + "\n 1. Visit the link shown below and complete sign-in" + "\n 2. Paste the code back into this terminal\n" + ) + + # Run `claude /login` inside an ephemeral temp directory so that Claude + # Code doesn't trust the home directory tree. The trap ensures cleanup + # even on signals. Using `bash -lc` bypasses zsh/p10k first-time setup. + login_cmd = ( + "bash -lc 'dir=$(mktemp -d) && " + "trap '\"'\"'rm -rf \"$dir\"'\"'\"' EXIT && " + 'cd "$dir" && claude /login\'' + ) + + ssh_cmd = [ + "ssh", + "-t", + *SSH_OPTS, + "-o", + "ConnectTimeout=10", + ssh_hostname, + login_cmd, + ] + + if verbose: + console.print(f"[dim] {' '.join(ssh_cmd)}[/dim]") + + try: + result = subprocess.run(ssh_cmd) + except KeyboardInterrupt: + console.print( + "\n[dim]Session interrupted. If you haven't completed authentication, " + "re-run: dropkit setup-claude [/dim]" + ) + raise typer.Exit(code=130) + except (OSError, subprocess.SubprocessError) as e: + console.print(f"[red]Failed to launch SSH: {e}[/red]") + raise typer.Exit(code=1) + if result.returncode == 255: + console.print( + "[red]SSH connection failed. Check that the droplet is running and accessible.[/red]" + ) + return False + # Non-zero exit from `claude /login` is normal — e.g. the user cancelled. 
+ return True + + +@app.command(name="setup-claude") +def setup_claude( + droplet_name: str = typer.Argument( + ..., help="Name of the droplet to configure.", autocompletion=complete_droplet_name + ), + sync_all: bool = typer.Option(False, "--sync-all", help="Sync all settings without prompting."), + verbose: bool = typer.Option(False, "--verbose", "-v", help="Show detailed output."), +) -> None: + """Set up Claude Code on an existing droplet. + + Installs Claude Code, shows an interactive prompt to select which settings + to sync (CLAUDE.md, settings, GitHub token, marketplaces), then opens an + SSH session for Claude subscription login. + + Use --sync-all to skip the prompt and sync everything (useful for scripts). + + Only droplets tagged with owner: can be configured. + """ + _, api = load_config_and_api() + droplet, username = find_user_droplet(api, droplet_name) + if not droplet: + tag = get_user_tag(username) + console.print( + f"[red]Error: Droplet '{droplet_name}' not found (filtered by tag: {tag})[/red]" + ) + raise typer.Exit(code=1) + + ssh_hostname = get_ssh_hostname(droplet_name) + + console.print(f"Setting up Claude Code on [cyan]{ssh_hostname}[/cyan]...\n") + + # Step 1: Install (fatal) + console.print(" Installing Claude Code... ", end="") + if not _install_claude_code(ssh_hostname, verbose): + console.print("\n[red]Error: Claude Code installation failed. 
Aborting.[/red]") + raise typer.Exit(code=1) + + # Step 2: Interactive prompt OR --sync-all + if sync_all: + selected: set[str] | None = None # None = sync everything + run_github = True + else: + choices = _discover_sync_choices() + selected_keys = _prompt_sync_selection(choices) if choices else set() + run_github = "github_token" in selected_keys + # Remove github_token from the set — it's handled by _auth_github, not _sync_settings + non_github = selected_keys - {"github_token"} + selected = non_github # empty set = sync nothing via _sync_settings + + # Step 3: GitHub auth (only when selected or --sync-all) + if run_github: + console.print(" Authenticating GitHub CLI... ", end="") + _auth_github(ssh_hostname, verbose) + + # Step 4: Sync settings (based on selection) + if selected is None or selected: + console.print(" Syncing settings... ", end="") + _sync_settings(ssh_hostname, verbose, remote_home=f"/home/{username}", selected=selected) + elif verbose: + console.print(" Syncing settings... [dim]skipped[/dim] (nothing selected)") + + # Step 5: Authenticate Claude Code (always last) + console.print(" Authenticating Claude Code... 
[dim]a link will appear below[/dim]") + if not _open_auth_session(ssh_hostname, verbose): + raise typer.Exit(code=1) + + @app.command() def version(): """Show the version of dropkit.""" diff --git a/tests/test_setup_claude.py b/tests/test_setup_claude.py new file mode 100644 index 0000000..b1adc3c --- /dev/null +++ b/tests/test_setup_claude.py @@ -0,0 +1,1594 @@ +"""Tests for setup-claude command helpers.""" + +import json +import subprocess +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +import typer + +from dropkit.main import ( + _GITHUB_TOKEN_PREFIXES, + _SETTINGS_SAFE_KEYS, + SSH_OPTS, + SyncChoice, + _auth_github, + _discover_sync_choices, + _install_claude_code, + _open_auth_session, + _prompt_sync_selection, + _sanitize_installed_plugins, + _sanitize_known_marketplaces, + _sanitize_settings, + _ssh_cmd, + _sync_settings, + setup_claude, +) + + +class TestSSHOpts: + """Tests for SSH_OPTS constant.""" + + def test_contains_strict_host_key_checking(self): + assert "-o" in SSH_OPTS + assert "StrictHostKeyChecking=no" in SSH_OPTS + + def test_contains_user_known_hosts_file(self): + assert "UserKnownHostsFile=/dev/null" in SSH_OPTS + + def test_does_not_contain_connect_timeout(self): + """ConnectTimeout is per-call-site, not in the shared constant.""" + for opt in SSH_OPTS: + assert not opt.startswith("ConnectTimeout") + + +class TestSSHCmd: + """Tests for _ssh_cmd helper.""" + + def test_basic_command(self): + cmd = _ssh_cmd("dropkit.test", "echo hello") + assert cmd[0] == "ssh" + assert "dropkit.test" in cmd + assert "echo hello" in cmd + assert "ConnectTimeout=10" in cmd + + def test_includes_ssh_opts(self): + cmd = _ssh_cmd("dropkit.test", "ls") + for opt in SSH_OPTS: + assert opt in cmd + + +class TestInstallClaudeCode: + """Tests for _install_claude_code function.""" + + @patch("dropkit.main.subprocess.run") + def test_already_installed(self, mock_run): + """Skip install if claude --version succeeds.""" + 
mock_run.return_value = MagicMock( + returncode=0, + stdout=b"1.0.0", + ) + assert _install_claude_code("dropkit.test", verbose=False) is True + # Only called once (version check), no install + mock_run.assert_called_once() + cmd = mock_run.call_args[0][0] + assert "claude --version" in cmd[-1] + + @patch("dropkit.main.subprocess.run") + def test_version_check_uses_login_shell(self, mock_run): + """Version check uses bash -lc to pick up PATH changes.""" + mock_run.return_value = MagicMock(returncode=0, stdout=b"1.0.0") + _install_claude_code("dropkit.test", verbose=False) + cmd = mock_run.call_args[0][0] + assert "bash -lc" in cmd[-1] + + @patch("dropkit.main.subprocess.run") + def test_install_success(self, mock_run): + """Install succeeds when claude not present.""" + # First call: version check fails; second: install succeeds; third: version after install + mock_run.side_effect = [ + MagicMock(returncode=1, stdout=b""), + MagicMock(returncode=0, stdout=b"", stderr=b""), + MagicMock(returncode=0, stdout=b"2.0.0"), + ] + assert _install_claude_code("dropkit.test", verbose=False) is True + assert mock_run.call_count == 3 + # Second call should be the install command + install_cmd = mock_run.call_args_list[1][0][0] + assert "curl" in install_cmd[-1] + + @patch("dropkit.main.subprocess.run") + def test_install_failure(self, mock_run): + """Returns False when install fails.""" + mock_run.side_effect = [ + MagicMock(returncode=1, stdout=b""), + MagicMock(returncode=1, stdout=b"", stderr=b"some error"), + ] + assert _install_claude_code("dropkit.test", verbose=False) is False + + @patch("dropkit.main.console") + @patch("dropkit.main.subprocess.run") + def test_install_failure_shows_last_stderr_line(self, mock_run, mock_console): + """Always shows last line of stderr on install failure.""" + mock_run.side_effect = [ + MagicMock(returncode=1, stdout=b""), + MagicMock(returncode=1, stdout=b"", stderr=b"line1\nline2\nactual error"), + ] + assert 
_install_claude_code("dropkit.test", verbose=False) is False + # Check that the last line of stderr was printed + print_calls = [str(c) for c in mock_console.print.call_args_list] + assert any("actual error" in c for c in print_calls) + + @patch("dropkit.main.console") + @patch("dropkit.main.subprocess.run") + def test_install_failure_verbose_shows_full_stderr(self, mock_run, mock_console): + """Shows full stderr in verbose mode on failure.""" + mock_run.side_effect = [ + MagicMock(returncode=1, stdout=b""), + MagicMock(returncode=1, stdout=b"", stderr=b"line1\nline2\nactual error"), + ] + assert _install_claude_code("dropkit.test", verbose=True) is False + print_calls = [str(c) for c in mock_console.print.call_args_list] + assert any("line1" in c and "line2" in c for c in print_calls) + + @patch("dropkit.main.console") + @patch("dropkit.main.subprocess.run") + def test_install_failure_falls_back_to_stdout(self, mock_run, mock_console): + """Shows stdout when stderr is empty on install failure.""" + mock_run.side_effect = [ + MagicMock(returncode=1, stdout=b""), + MagicMock(returncode=1, stdout=b"stdout hint here", stderr=b""), + ] + assert _install_claude_code("dropkit.test", verbose=False) is False + print_calls = [str(c) for c in mock_console.print.call_args_list] + assert any("stdout hint" in c for c in print_calls) + + @patch("dropkit.main.console") + @patch("dropkit.main.subprocess.run") + def test_install_failure_no_output_shows_guidance(self, mock_run, mock_console): + """Shows actionable guidance when both stderr and stdout are empty.""" + mock_run.side_effect = [ + MagicMock(returncode=1, stdout=b""), + MagicMock(returncode=1, stdout=b"", stderr=b""), + ] + assert _install_claude_code("dropkit.test", verbose=False) is False + print_calls = [str(c) for c in mock_console.print.call_args_list] + assert any("--verbose" in c for c in print_calls) + + @patch("dropkit.main.subprocess.run") + def test_uses_ssh_opts(self, mock_run): + """SSH commands include 
SSH_OPTS.""" + mock_run.return_value = MagicMock(returncode=0, stdout=b"1.0.0") + _install_claude_code("dropkit.test", verbose=False) + cmd = mock_run.call_args[0][0] + for opt in SSH_OPTS: + assert opt in cmd + + @patch("dropkit.main.subprocess.run") + def test_timeout_expired_returns_false(self, mock_run): + """Returns False on SSH timeout.""" + mock_run.side_effect = subprocess.TimeoutExpired("ssh", 30) + assert _install_claude_code("dropkit.test", verbose=False) is False + + @patch("dropkit.main.subprocess.run") + def test_subprocess_error_returns_false(self, mock_run): + """Returns False when subprocess raises SubprocessError.""" + mock_run.side_effect = subprocess.SubprocessError("ssh failed") + assert _install_claude_code("dropkit.test", verbose=False) is False + + @patch("dropkit.main.console") + @patch("dropkit.main.subprocess.run") + def test_ssh_connection_failure_exit_255(self, mock_run, mock_console): + """Exit code 255 is reported as SSH connection failure, not 'not installed'.""" + mock_run.return_value = MagicMock( + returncode=255, + stdout=b"", + stderr=b"ssh: connect to host example.com port 22: Connection refused", + ) + assert _install_claude_code("dropkit.test", verbose=False) is False + print_calls = [str(c) for c in mock_console.print.call_args_list] + assert any("Connection refused" in c for c in print_calls) + # Should NOT proceed to install + mock_run.assert_called_once() + + @patch("dropkit.main.console") + @patch("dropkit.main.subprocess.run") + def test_verbose_logs_fallthrough_to_install(self, mock_run, mock_console): + """In verbose mode, logs why install was triggered on unexpected exit code.""" + mock_run.side_effect = [ + MagicMock(returncode=127, stdout=b"", stderr=b"command not found"), + MagicMock(returncode=0, stdout=b"", stderr=b""), # install succeeds + MagicMock(returncode=0, stdout=b"1.0.0"), # version check + ] + assert _install_claude_code("dropkit.test", verbose=True) is True + print_calls = [str(c) for c in 
mock_console.print.call_args_list] + assert any("claude not found" in c for c in print_calls) + + @patch("dropkit.main.console") + @patch("dropkit.main.subprocess.run") + def test_install_phase_timeout_returns_false(self, mock_run, mock_console): + """Returns False when version check passes (not installed) but install times out.""" + mock_run.side_effect = [ + MagicMock(returncode=1, stdout=b""), # version check: not installed + subprocess.TimeoutExpired("ssh", 300), # install times out + ] + assert _install_claude_code("dropkit.test", verbose=False) is False + print_calls = [str(c) for c in mock_console.print.call_args_list] + assert any("FAILED" in c for c in print_calls) + + @patch("dropkit.main.console") + @patch("dropkit.main.subprocess.run") + def test_post_install_version_check_failure_returns_true(self, mock_run, mock_console): + """Install succeeded but version probe failed still returns True.""" + mock_run.side_effect = [ + MagicMock(returncode=1, stdout=b""), # not installed + MagicMock(returncode=0, stdout=b"", stderr=b""), # install succeeds + MagicMock(returncode=1, stdout=b"", stderr=b""), # version check fails + ] + assert _install_claude_code("dropkit.test", verbose=False) is True + print_calls = [str(c) for c in mock_console.print.call_args_list] + assert any("done" in c for c in print_calls) + + +class TestAuthGithub: + """Tests for _auth_github function.""" + + @patch.dict("os.environ", {}, clear=True) + @patch("dropkit.main.subprocess.run") + def test_skipped_without_token(self, mock_run): + """Skips when GITHUB_TOKEN not set.""" + _auth_github("dropkit.test", verbose=False) + mock_run.assert_not_called() + + @patch.dict("os.environ", {"GITHUB_TOKEN": ""}) + @patch("dropkit.main.subprocess.run") + def test_skipped_with_empty_token(self, mock_run): + """Skips when GITHUB_TOKEN is empty string.""" + _auth_github("dropkit.test", verbose=False) + mock_run.assert_not_called() + + @patch.dict("os.environ", {"GITHUB_TOKEN": "ghp_test123"}) + 
@patch("dropkit.main.subprocess.run") + def test_auth_success(self, mock_run): + """Pipes token into gh auth login.""" + mock_run.return_value = MagicMock(returncode=0) + _auth_github("dropkit.test", verbose=False) + mock_run.assert_called_once() + cmd = mock_run.call_args[0][0] + assert "gh auth login --with-token" in cmd[-1] + assert mock_run.call_args[1]["input"] == b"ghp_test123" + + @patch.dict("os.environ", {"GITHUB_TOKEN": "ghp_test123"}) + @patch("dropkit.main.subprocess.run") + def test_auth_failure_shows_stderr(self, mock_run): + """Shows actual stderr on auth failure, not a generic message.""" + mock_run.return_value = MagicMock( + returncode=1, + stderr=b"error connecting to api.github.com", + ) + # Should not raise + _auth_github("dropkit.test", verbose=False) + + @patch.dict("os.environ", {"GITHUB_TOKEN": "ghp_test123"}) + @patch("dropkit.main.subprocess.run") + def test_timeout_does_not_raise(self, mock_run): + """Timeout is caught gracefully.""" + mock_run.side_effect = subprocess.TimeoutExpired("ssh", 60) + # Should not raise + _auth_github("dropkit.test", verbose=False) + + @patch.dict("os.environ", {"GITHUB_TOKEN": "ghp_test123"}) + @patch("dropkit.main.console") + @patch("dropkit.main.subprocess.run") + def test_gh_not_found_reports_clearly(self, mock_run, mock_console): + """Reports 'gh CLI not installed' when gh is missing on remote.""" + mock_run.return_value = MagicMock( + returncode=127, + stderr=b"bash: gh: command not found", + ) + _auth_github("dropkit.test", verbose=False) + print_calls = [str(c) for c in mock_console.print.call_args_list] + assert any("not installed" in c for c in print_calls) + + @patch.dict("os.environ", {"GITHUB_TOKEN": "not-a-github-token"}) + @patch("dropkit.main.console") + @patch("dropkit.main.subprocess.run") + def test_invalid_token_format_skipped(self, mock_run, mock_console): + """Skips auth when GITHUB_TOKEN doesn't look like a GitHub token.""" + _auth_github("dropkit.test", verbose=False) + 
mock_run.assert_not_called() + print_calls = [str(c) for c in mock_console.print.call_args_list] + assert any("does not look like" in c for c in print_calls) + + @patch.dict("os.environ", {"GITHUB_TOKEN": "ghp_test123"}) + @patch("dropkit.main.console") + @patch("dropkit.main.subprocess.run") + def test_ssh_exit_255_reports_connection_lost(self, mock_run, mock_console): + """Exit code 255 is reported as SSH connection lost.""" + mock_run.return_value = MagicMock( + returncode=255, + stderr=b"ssh: connect to host example.com port 22: Connection refused", + ) + _auth_github("dropkit.test", verbose=False) + print_calls = [str(c) for c in mock_console.print.call_args_list] + assert any("SSH connection lost" in c for c in print_calls) + + def test_all_github_token_prefixes_accepted(self): + """All known GitHub token prefixes pass validation.""" + for prefix in _GITHUB_TOKEN_PREFIXES: + assert f"{prefix}abc123".startswith(_GITHUB_TOKEN_PREFIXES) + + +class TestDiscoverSyncChoices: + """Tests for _discover_sync_choices function.""" + + def test_base_items_present(self, tmp_path): + """CLAUDE.md and Settings appear when local files exist.""" + claude_dir = tmp_path / ".claude" + claude_dir.mkdir() + (claude_dir / "CLAUDE.md").write_text("# test") + (claude_dir / "settings.json").write_text("{}") + + with patch("dropkit.main.Path.expanduser", return_value=tmp_path / ".claude" / "CLAUDE.md"): + # Need to patch per-path, so use a side-effect instead + pass + + def expand_side_effect(self): + local_path = str(self) + if "CLAUDE.md" in local_path: + return claude_dir / "CLAUDE.md" + if "settings.json" in local_path: + return claude_dir / "settings.json" + if "known_marketplaces.json" in local_path: + return tmp_path / "nonexistent" + return tmp_path / "nonexistent" + + with ( + patch.object(Path, "expanduser", expand_side_effect), + patch.dict("os.environ", {}, clear=True), + ): + choices = _discover_sync_choices() + + keys = [c.key for c in choices] + assert "claude_md" in 
keys + assert "settings" in keys + + def test_marketplace_discovery(self, tmp_path): + """Marketplace names come from known_marketplaces.json keys.""" + claude_dir = tmp_path / ".claude" + plugins_dir = claude_dir / "plugins" + plugins_dir.mkdir(parents=True) + mp_file = plugins_dir / "known_marketplaces.json" + mp_file.write_text(json.dumps({"mp-b": {}, "mp-a": {}})) + + def expand_side_effect(self): + local_path = str(self) + if "known_marketplaces.json" in local_path: + return mp_file + return tmp_path / "nonexistent" + + with ( + patch.object(Path, "expanduser", expand_side_effect), + patch.dict("os.environ", {}, clear=True), + ): + choices = _discover_sync_choices() + + mp_keys = [c.key for c in choices if c.key.startswith("marketplace:")] + assert mp_keys == ["marketplace:mp-a", "marketplace:mp-b"] # sorted + + def test_missing_marketplace_file(self, tmp_path): + """Returns base items when marketplace file doesn't exist.""" + + def expand_side_effect(self): + local_path = str(self) + if "CLAUDE.md" in local_path: + p = tmp_path / "CLAUDE.md" + p.write_text("# test") + return p + return tmp_path / "nonexistent" + + with ( + patch.object(Path, "expanduser", expand_side_effect), + patch.dict("os.environ", {}, clear=True), + ): + choices = _discover_sync_choices() + + keys = [c.key for c in choices] + assert "claude_md" in keys + assert not any(k.startswith("marketplace:") for k in keys) + + def test_malformed_json_graceful(self, tmp_path): + """Malformed marketplace JSON doesn't crash.""" + plugins_dir = tmp_path / ".claude" / "plugins" + plugins_dir.mkdir(parents=True) + (plugins_dir / "known_marketplaces.json").write_text("{bad json") + + def expand_side_effect(self): + local_path = str(self) + if "known_marketplaces.json" in local_path: + return plugins_dir / "known_marketplaces.json" + return tmp_path / "nonexistent" + + with ( + patch.object(Path, "expanduser", expand_side_effect), + patch.dict("os.environ", {}, clear=True), + ): + choices = 
_discover_sync_choices() + + # Should return without marketplace items, no exception + assert not any(c.key.startswith("marketplace:") for c in choices) + + @patch.dict("os.environ", {"GITHUB_TOKEN": "ghp_test123"}) + def test_github_token_included(self, tmp_path): + """GitHub token choice appears when GITHUB_TOKEN is set with valid prefix.""" + + def expand_side_effect(self): + return tmp_path / "nonexistent" + + with patch.object(Path, "expanduser", expand_side_effect): + choices = _discover_sync_choices() + + keys = [c.key for c in choices] + assert "github_token" in keys + + @patch.dict("os.environ", {"GITHUB_TOKEN": "not-a-token"}) + def test_github_token_excluded_invalid_prefix(self, tmp_path): + """GitHub token choice excluded when prefix is invalid.""" + + def expand_side_effect(self): + return tmp_path / "nonexistent" + + with patch.object(Path, "expanduser", expand_side_effect): + choices = _discover_sync_choices() + + keys = [c.key for c in choices] + assert "github_token" not in keys + + def test_skips_items_where_source_missing(self, tmp_path): + """Skips CLAUDE.md and settings when local files don't exist.""" + + def expand_side_effect(self): + return tmp_path / "nonexistent" + + with ( + patch.object(Path, "expanduser", expand_side_effect), + patch.dict("os.environ", {}, clear=True), + ): + choices = _discover_sync_choices() + + keys = [c.key for c in choices] + assert "claude_md" not in keys + assert "settings" not in keys + + +class TestPromptSyncSelection: + """Tests for _prompt_sync_selection function.""" + + def test_all_returns_all_keys(self): + """'all' input returns all choice keys.""" + choices = [ + SyncChoice("A", "key_a"), + SyncChoice("B", "key_b"), + ] + with patch("dropkit.main.Prompt.ask", return_value="all"): + result = _prompt_sync_selection(choices) + assert result == {"key_a", "key_b"} + + def test_none_returns_empty_set(self): + """'none' input returns empty set.""" + choices = [SyncChoice("A", "key_a")] + with 
patch("dropkit.main.Prompt.ask", return_value="none"): + result = _prompt_sync_selection(choices) + assert result == set() + + def test_comma_separated_numbers(self): + """Comma-separated numbers select corresponding items.""" + choices = [ + SyncChoice("A", "key_a"), + SyncChoice("B", "key_b"), + SyncChoice("C", "key_c"), + ] + with patch("dropkit.main.Prompt.ask", return_value="1,3"): + result = _prompt_sync_selection(choices) + assert result == {"key_a", "key_c"} + + def test_invalid_input_exits(self): + """Non-numeric input causes typer.Exit(1).""" + choices = [SyncChoice("A", "key_a")] + with ( + patch("dropkit.main.Prompt.ask", return_value="abc"), + pytest.raises(typer.Exit) as exc_info, + ): + _prompt_sync_selection(choices) + assert exc_info.value.exit_code == 1 + + def test_out_of_range_exits_with_error(self): + """Numbers out of range cause typer.Exit(1).""" + choices = [SyncChoice("A", "key_a")] + with ( + patch("dropkit.main.Prompt.ask", return_value="1,99"), + pytest.raises(typer.Exit) as exc_info, + ): + _prompt_sync_selection(choices) + assert exc_info.value.exit_code == 1 + + def test_empty_choices_returns_empty(self): + """Empty choices list returns empty set without prompting.""" + result = _prompt_sync_selection([]) + assert result == set() + + def test_default_is_all(self): + """Default answer is 'all' (just pressing enter).""" + choices = [SyncChoice("A", "key_a")] + with patch("dropkit.main.Prompt.ask", return_value="all") as mock_ask: + _prompt_sync_selection(choices) + assert mock_ask.call_args[1]["default"] == "all" + + +# Shared fixture for _sync_settings tests that need a local file +def _make_expand_side_effect(claude_md_path, fallback_dir): + """Create an expanduser side effect that resolves CLAUDE.md and falls back for others.""" + + def expand_side_effect(self): + local_path = str(self) + if "CLAUDE.md" in local_path: + return claude_md_path + return fallback_dir / "nonexistent" + + return expand_side_effect + + +class 
TestSanitizeSettings: + """Tests for _sanitize_settings pure function.""" + + def test_strips_sensitive_keys(self): + """Sensitive keys like permissions, env, hooks are removed.""" + data = { + "model": "opus", + "permissions": {"/path/to/project": {"allow": ["Read"]}}, + "env": {"SECRET_KEY": "abc123"}, + "hooks": {"pre-commit": "/local/script.sh"}, + "apiKeyHelper": "/usr/local/bin/get-key", + "sandbox": {"allowedDomains": ["internal.corp"]}, + "statusLine": {"command": "/local/status.sh"}, + } + result = _sanitize_settings(data) + assert "model" in result + for key in ("permissions", "env", "hooks", "apiKeyHelper", "sandbox", "statusLine"): + assert key not in result + + def test_preserves_allowlisted_keys(self): + """All allowlisted keys pass through unchanged.""" + data = { + "model": "opus", + "enabledPlugins": ["plugin-a"], + "language": "en", + "effortLevel": "high", + "fastMode": True, + } + result = _sanitize_settings(data) + assert result == data + + def test_empty_dict(self): + """Empty input returns empty output.""" + assert _sanitize_settings({}) == {} + + def test_does_not_mutate_input(self): + """Original dict is unchanged after sanitization.""" + data = {"model": "opus", "env": {"SECRET": "value"}} + original = data.copy() + _sanitize_settings(data) + assert data == original + + def test_unknown_keys_dropped(self): + """Keys not in the allowlist are excluded.""" + data = {"model": "opus", "some_future_key": "value", "anotherNewThing": [1, 2]} + result = _sanitize_settings(data) + assert result == {"model": "opus"} + + def test_all_safe_keys_accepted(self): + """Every key in _SETTINGS_SAFE_KEYS passes through.""" + data = {k: f"value-{k}" for k in _SETTINGS_SAFE_KEYS} + result = _sanitize_settings(data) + assert set(result.keys()) == _SETTINGS_SAFE_KEYS + + @pytest.mark.parametrize( + "key", + [ + "alwaysThinkingEnabled", + "teammateMode", + "skipWebFetchPreflight", + "attribution", + "includeCoAuthoredBy", + "skippedMarketplaces", + 
"skippedPlugins", + "pluginConfigs", + "companyAnnouncements", + ], + ) + def test_new_safe_keys_preserved(self, key): + """Each newly added safe key passes through sanitization.""" + data = {key: "test-value"} + result = _sanitize_settings(data) + assert result[key] == "test-value" + + def test_strips_file_marketplace_sources(self): + """extraKnownMarketplaces entries with type 'file' are stripped.""" + data = { + "extraKnownMarketplaces": { + "local-mp": {"type": "file", "path": "/Users/brad/marketplace"}, + }, + } + result = _sanitize_settings(data) + assert result["extraKnownMarketplaces"] == {} + + def test_strips_directory_marketplace_sources(self): + """extraKnownMarketplaces entries with type 'directory' are stripped.""" + data = { + "extraKnownMarketplaces": { + "dir-mp": {"type": "directory", "path": "/Users/brad/marketplaces"}, + }, + } + result = _sanitize_settings(data) + assert result["extraKnownMarketplaces"] == {} + + @pytest.mark.parametrize("source_type", ["url", "github", "git", "npm"]) + def test_preserves_remote_marketplace_sources(self, source_type): + """extraKnownMarketplaces entries with remote types are preserved.""" + data = { + "extraKnownMarketplaces": { + "remote-mp": {"type": source_type, "url": "https://example.com/mp"}, + }, + } + result = _sanitize_settings(data) + assert "remote-mp" in result["extraKnownMarketplaces"] + + def test_mixed_marketplace_sources(self): + """Local-path sources stripped while remote sources kept.""" + data = { + "extraKnownMarketplaces": { + "local-file": {"type": "file", "path": "/Users/brad/mp"}, + "local-dir": {"type": "directory", "path": "/Users/brad/mps"}, + "remote-url": {"type": "url", "url": "https://example.com/mp"}, + "remote-git": {"type": "github", "repo": "org/repo"}, + }, + } + result = _sanitize_settings(data) + assert "local-file" not in result["extraKnownMarketplaces"] + assert "local-dir" not in result["extraKnownMarketplaces"] + assert "remote-url" in result["extraKnownMarketplaces"] + 
assert "remote-git" in result["extraKnownMarketplaces"] + + +class TestSanitizeKnownMarketplaces: + """Tests for _sanitize_known_marketplaces pure function.""" + + def test_rewrites_install_location(self): + """Rewrites installLocation paths from local to remote home.""" + data = { + "marketplace-a": { + "installLocation": "/Users/brad/.claude/plugins/marketplaces/marketplace-a", + "url": "https://example.com/repo", + }, + } + result = _sanitize_known_marketplaces(data, "/Users/brad", "/home/brad") + assert result["marketplace-a"]["installLocation"] == ( + "/home/brad/.claude/plugins/marketplaces/marketplace-a" + ) + assert result["marketplace-a"]["url"] == "https://example.com/repo" + + def test_empty_dict(self): + """Empty input returns empty output.""" + assert _sanitize_known_marketplaces({}, "/Users/brad", "/home/brad") == {} + + def test_leaves_entries_without_install_location(self): + """Entries without installLocation are preserved unchanged.""" + data = {"marketplace-b": {"url": "https://example.com/repo"}} + result = _sanitize_known_marketplaces(data, "/Users/brad", "/home/brad") + assert result == data + + def test_does_not_mutate_input(self): + """Original dict is unchanged after sanitization.""" + data = { + "m": {"installLocation": "/Users/brad/.claude/plugins/marketplaces/m"}, + } + original_loc = data["m"]["installLocation"] + _sanitize_known_marketplaces(data, "/Users/brad", "/home/brad") + assert data["m"]["installLocation"] == original_loc + + def test_marketplace_filter_includes_only_matching(self): + """marketplace_filter restricts output to matching keys.""" + data = { + "mp-a": {"installLocation": "/Users/brad/.claude/plugins/marketplaces/mp-a"}, + "mp-b": {"installLocation": "/Users/brad/.claude/plugins/marketplaces/mp-b"}, + "mp-c": {"url": "https://example.com"}, + } + result = _sanitize_known_marketplaces( + data, "/Users/brad", "/home/brad", marketplace_filter={"mp-a", "mp-c"} + ) + assert "mp-a" in result + assert "mp-b" not in result + 
assert "mp-c" in result + + def test_marketplace_filter_none_includes_all(self): + """marketplace_filter=None includes everything (default).""" + data = {"mp-a": {}, "mp-b": {}} + result = _sanitize_known_marketplaces( + data, "/Users/brad", "/home/brad", marketplace_filter=None + ) + assert set(result.keys()) == {"mp-a", "mp-b"} + + +class TestSanitizeInstalledPlugins: + """Tests for _sanitize_installed_plugins pure function (v2 format).""" + + def test_rewrites_install_path_for_user_scope(self): + """Rewrites installPath for user-scope entries.""" + data = { + "version": 2, + "plugins": { + "plugin-a@source": [ + { + "scope": "user", + "installPath": "/Users/brad/.claude/plugins/cache/plugin-a", + "version": "1.0.0", + }, + ], + }, + } + result = _sanitize_installed_plugins(data, "/Users/brad", "/home/brad") + entry = result["plugins"]["plugin-a@source"][0] + assert entry["installPath"] == "/home/brad/.claude/plugins/cache/plugin-a" + assert entry["version"] == "1.0.0" + + def test_strips_local_scope_entries(self): + """Local-scope entries are removed entirely.""" + data = { + "version": 2, + "plugins": { + "plugin-local@source": [ + { + "scope": "local", + "installPath": "/Users/brad/project/.claude/plugins/cache/plugin-local", + "projectPath": "/Users/brad/project", + }, + ], + "plugin-user@source": [ + { + "scope": "user", + "installPath": "/Users/brad/.claude/plugins/cache/plugin-user", + }, + ], + }, + } + result = _sanitize_installed_plugins(data, "/Users/brad", "/home/brad") + assert "plugin-local@source" not in result["plugins"] + assert "plugin-user@source" in result["plugins"] + + def test_preserves_version_field(self): + """Entry version field is preserved in output.""" + data = { + "version": 2, + "plugins": { + "p@source": [ + { + "scope": "user", + "installPath": "/Users/brad/.claude/plugins/cache/p", + "version": "2.3.1", + }, + ], + }, + } + result = _sanitize_installed_plugins(data, "/Users/brad", "/home/brad") + assert 
result["plugins"]["p@source"][0]["version"] == "2.3.1" + + def test_empty_plugins(self): + """Empty plugins dict returns empty plugins.""" + data = {"version": 2, "plugins": {}} + result = _sanitize_installed_plugins(data, "/Users/brad", "/home/brad") + assert result == {"version": 2, "plugins": {}} + + def test_missing_plugins_key(self): + """Missing plugins key is handled gracefully.""" + data = {"version": 2} + result = _sanitize_installed_plugins(data, "/Users/brad", "/home/brad") + assert result == {"version": 2, "plugins": {}} + + def test_does_not_mutate_input(self): + """Original dict is unchanged after sanitization.""" + data = { + "version": 2, + "plugins": { + "p@source": [ + { + "scope": "user", + "installPath": "/Users/brad/.claude/plugins/cache/p", + }, + ], + }, + } + original_path = data["plugins"]["p@source"][0]["installPath"] + _sanitize_installed_plugins(data, "/Users/brad", "/home/brad") + assert data["plugins"]["p@source"][0]["installPath"] == original_path + + def test_entries_without_scope_are_kept(self): + """Entries without a scope field are kept and rewritten.""" + data = { + "version": 2, + "plugins": { + "p@source": [ + {"installPath": "/Users/brad/.claude/plugins/cache/p"}, + ], + }, + } + result = _sanitize_installed_plugins(data, "/Users/brad", "/home/brad") + assert "p@source" in result["plugins"] + assert ( + result["plugins"]["p@source"][0]["installPath"] == "/home/brad/.claude/plugins/cache/p" + ) + + def test_v2_wrapper_preserved(self): + """Top-level version and plugins wrapper is preserved.""" + data = { + "version": 2, + "plugins": { + "plugin-a@source": [ + { + "scope": "user", + "installPath": "/Users/brad/.claude/plugins/cache/plugin-a", + }, + ], + }, + } + result = _sanitize_installed_plugins(data, "/Users/brad", "/home/brad") + assert result["version"] == 2 + assert "plugins" in result + + def test_mixed_scopes_in_same_plugin(self): + """A plugin with both local and user entries keeps only user entries.""" + data = { + 
"version": 2, + "plugins": { + "plugin-a@source": [ + { + "scope": "local", + "installPath": "/Users/brad/project/.claude/plugins/cache/a", + }, + { + "scope": "user", + "installPath": "/Users/brad/.claude/plugins/cache/a", + }, + ], + }, + } + result = _sanitize_installed_plugins(data, "/Users/brad", "/home/brad") + entries = result["plugins"]["plugin-a@source"] + assert len(entries) == 1 + assert entries[0]["scope"] == "user" + assert entries[0]["installPath"] == "/home/brad/.claude/plugins/cache/a" + + def test_marketplace_filter_includes_only_matching(self): + """marketplace_filter restricts to plugins from matching marketplaces.""" + data = { + "version": 2, + "plugins": { + "plugin-a@mp-a": [ + {"scope": "user", "installPath": "/Users/brad/.claude/plugins/cache/a"}, + ], + "plugin-b@mp-b": [ + {"scope": "user", "installPath": "/Users/brad/.claude/plugins/cache/b"}, + ], + }, + } + result = _sanitize_installed_plugins( + data, "/Users/brad", "/home/brad", marketplace_filter={"mp-a"} + ) + assert "plugin-a@mp-a" in result["plugins"] + assert "plugin-b@mp-b" not in result["plugins"] + + def test_marketplace_filter_none_includes_all(self): + """marketplace_filter=None includes all plugins (default).""" + data = { + "version": 2, + "plugins": { + "plugin-a@mp-a": [ + {"scope": "user", "installPath": "/Users/brad/.claude/plugins/cache/a"}, + ], + "plugin-b@mp-b": [ + {"scope": "user", "installPath": "/Users/brad/.claude/plugins/cache/b"}, + ], + }, + } + result = _sanitize_installed_plugins( + data, "/Users/brad", "/home/brad", marketplace_filter=None + ) + assert "plugin-a@mp-a" in result["plugins"] + assert "plugin-b@mp-b" in result["plugins"] + + +class TestSyncSettings: + """Tests for _sync_settings function.""" + + @patch("dropkit.main.subprocess.run") + def test_skips_missing_paths(self, mock_run, tmp_path): + """Skips paths that don't exist locally.""" + with patch.object(Path, "expanduser", return_value=tmp_path / "nonexistent"): + 
_sync_settings("dropkit.test", verbose=True, remote_home="/home/testuser") + # No rsync calls when all paths missing + rsync_calls = [c for c in mock_run.call_args_list if "rsync" in str(c)] + assert len(rsync_calls) == 0 + + @patch("dropkit.main.subprocess.run") + def test_sync_existing_file(self, mock_run, tmp_path): + """Syncs a file that exists locally and targets correct remote path.""" + claude_md = tmp_path / ".claude" / "CLAUDE.md" + claude_md.parent.mkdir(parents=True) + claude_md.write_text("# test") + + mock_run.return_value = MagicMock(returncode=0) + + with patch.object(Path, "expanduser", _make_expand_side_effect(claude_md, tmp_path)): + _sync_settings("dropkit.test", verbose=False, remote_home="/home/testuser") + + # Should have at least one rsync call targeting the correct remote + rsync_calls = [c for c in mock_run.call_args_list if "rsync" in str(c[0][0])] + assert len(rsync_calls) >= 1 + rsync_cmd = rsync_calls[0][0][0] + assert "dropkit.test:.claude/CLAUDE.md" in rsync_cmd[-1] + + @patch("dropkit.main.console") + @patch("dropkit.main.subprocess.run") + def test_mkdir_failure_skips_rsync(self, mock_run, mock_console, tmp_path): + """When mkdir fails, rsync is skipped and failure is recorded.""" + claude_md = tmp_path / ".claude" / "CLAUDE.md" + claude_md.parent.mkdir(parents=True) + claude_md.write_text("# test") + + # mkdir fails + mock_run.return_value = MagicMock(returncode=1, stderr=b"permission denied") + + with patch.object(Path, "expanduser", _make_expand_side_effect(claude_md, tmp_path)): + _sync_settings("dropkit.test", verbose=False, remote_home="/home/testuser") + + # Should report FAILED + print_calls = [str(c) for c in mock_console.print.call_args_list] + assert any("FAILED" in c for c in print_calls) + + @patch("dropkit.main.console") + @patch("dropkit.main.subprocess.run") + def test_partial_sync_failure(self, mock_run, mock_console, tmp_path): + """Reports failure when rsync fails for one path.""" + claude_md = tmp_path / 
".claude" / "CLAUDE.md" + claude_md.parent.mkdir(parents=True) + claude_md.write_text("# test") + + # mkdir succeeds, rsync fails + mock_run.side_effect = [ + MagicMock(returncode=0, stderr=b""), # mkdir + MagicMock(returncode=1, stderr=b"rsync error"), # rsync + ] + + with patch.object(Path, "expanduser", _make_expand_side_effect(claude_md, tmp_path)): + _sync_settings("dropkit.test", verbose=False, remote_home="/home/testuser") + + print_calls = [str(c) for c in mock_console.print.call_args_list] + assert any("FAILED" in c for c in print_calls) + + @patch("dropkit.main.console") + @patch("dropkit.main.subprocess.run") + def test_rsync_failure_includes_reason(self, mock_run, mock_console, tmp_path): + """Rsync failure includes the error reason even without --verbose.""" + claude_md = tmp_path / ".claude" / "CLAUDE.md" + claude_md.parent.mkdir(parents=True) + claude_md.write_text("# test") + + mock_run.side_effect = [ + MagicMock(returncode=0, stderr=b""), # mkdir + MagicMock(returncode=1, stderr=b"rsync: connection unexpectedly closed"), # rsync + ] + + with patch.object(Path, "expanduser", _make_expand_side_effect(claude_md, tmp_path)): + _sync_settings("dropkit.test", verbose=False, remote_home="/home/testuser") + + print_calls = [str(c) for c in mock_console.print.call_args_list] + assert any("connection unexpectedly closed" in c for c in print_calls) + + @patch("dropkit.main.console") + @patch("dropkit.main.subprocess.run") + def test_no_local_settings_shows_skipped(self, mock_run, mock_console, tmp_path): + """Shows 'skipped' when no local settings files exist.""" + with patch.object(Path, "expanduser", return_value=tmp_path / "nonexistent"): + _sync_settings("dropkit.test", verbose=False, remote_home="/home/testuser") + + print_calls = [str(c) for c in mock_console.print.call_args_list] + assert any("skipped" in c and "no local settings" in c for c in print_calls) + + @patch("dropkit.main.subprocess.run") + def test_timeout_caught_gracefully(self, mock_run, 
tmp_path): + """Timeout during sync is caught, not raised.""" + claude_md = tmp_path / ".claude" / "CLAUDE.md" + claude_md.parent.mkdir(parents=True) + claude_md.write_text("# test") + + mock_run.side_effect = subprocess.TimeoutExpired("ssh", 30) + + with patch.object(Path, "expanduser", _make_expand_side_effect(claude_md, tmp_path)): + # Should not raise + _sync_settings("dropkit.test", verbose=False, remote_home="/home/testuser") + + @patch("dropkit.main.console") + @patch("dropkit.main.subprocess.run") + def test_rsync_not_found_caught(self, mock_run, mock_console, tmp_path): + """FileNotFoundError when rsync is not installed is caught gracefully.""" + claude_md = tmp_path / ".claude" / "CLAUDE.md" + claude_md.parent.mkdir(parents=True) + claude_md.write_text("# test") + + mock_run.side_effect = [ + MagicMock(returncode=0, stderr=b""), # mkdir + FileNotFoundError("rsync not found"), # rsync binary missing + ] + + with patch.object(Path, "expanduser", _make_expand_side_effect(claude_md, tmp_path)): + _sync_settings("dropkit.test", verbose=False, remote_home="/home/testuser") + + print_calls = [str(c) for c in mock_console.print.call_args_list] + assert any("FAILED" in c for c in print_calls) + + @patch("dropkit.main.subprocess.run") + def test_directory_trailing_slash(self, mock_run, tmp_path): + """Directories get trailing slash in rsync source.""" + plugins_dir = tmp_path / ".claude" / "plugins" + plugins_dir.mkdir(parents=True) + + def expand_side_effect(self): + local_path = str(self) + if "plugins" in local_path: + return plugins_dir + return tmp_path / "nonexistent" + + mock_run.return_value = MagicMock(returncode=0) + + with patch.object(Path, "expanduser", expand_side_effect): + _sync_settings("dropkit.test", verbose=False, remote_home="/home/testuser") + + rsync_calls = [c for c in mock_run.call_args_list if "rsync" in str(c[0][0])] + assert len(rsync_calls) >= 1 + rsync_cmd = rsync_calls[0][0][0] + # Source should end with / for directories + source_arg 
= [ + a + for a in rsync_cmd + if "plugins" in a and "dropkit.test" not in a and not a.startswith("--exclude=") + ] + assert len(source_arg) == 1 + assert source_arg[0].endswith("/") + + def test_discover_returns_expected_base_items(self, tmp_path): + """_discover_sync_choices returns CLAUDE.md and Settings when files exist.""" + claude_dir = tmp_path / ".claude" + claude_dir.mkdir() + (claude_dir / "CLAUDE.md").write_text("# test") + (claude_dir / "settings.json").write_text("{}") + + def expand_side_effect(self): + local_path = str(self) + if "CLAUDE.md" in local_path: + return claude_dir / "CLAUDE.md" + if "settings.json" in local_path: + return claude_dir / "settings.json" + return tmp_path / "nonexistent" + + with ( + patch.object(Path, "expanduser", expand_side_effect), + patch.dict("os.environ", {}, clear=True), + ): + choices = _discover_sync_choices() + + keys = [c.key for c in choices] + assert "claude_md" in keys + assert "settings" in keys + + @patch("dropkit.main.subprocess.run") + def test_selected_subset_only_syncs_chosen(self, mock_run, tmp_path): + """When selected={'claude_md'}, only CLAUDE.md is synced.""" + claude_md = tmp_path / ".claude" / "CLAUDE.md" + claude_md.parent.mkdir(parents=True) + claude_md.write_text("# test") + + mock_run.return_value = MagicMock(returncode=0) + + with patch.object(Path, "expanduser", _make_expand_side_effect(claude_md, tmp_path)): + _sync_settings( + "dropkit.test", + verbose=False, + remote_home="/home/testuser", + selected={"claude_md"}, + ) + + rsync_calls = [c for c in mock_run.call_args_list if "rsync" in str(c[0][0])] + assert len(rsync_calls) >= 1 + # All rsync calls should be for CLAUDE.md only + for call in rsync_calls: + cmd_str = " ".join(call[0][0]) + assert "CLAUDE.md" in cmd_str + + @patch("dropkit.main.subprocess.run") + def test_selected_none_syncs_all(self, mock_run, tmp_path): + """When selected=None, all available paths are synced.""" + claude_md = tmp_path / ".claude" / "CLAUDE.md" + 
claude_md.parent.mkdir(parents=True) + claude_md.write_text("# test") + + mock_run.return_value = MagicMock(returncode=0) + + with patch.object(Path, "expanduser", _make_expand_side_effect(claude_md, tmp_path)): + _sync_settings( + "dropkit.test", + verbose=False, + remote_home="/home/testuser", + selected=None, + ) + + rsync_calls = [c for c in mock_run.call_args_list if "rsync" in str(c[0][0])] + assert len(rsync_calls) >= 1 + + @patch("dropkit.main.console") + @patch("dropkit.main.subprocess.run") + def test_empty_selected_shows_skipped(self, mock_run, mock_console, tmp_path): + """When selected=set(), shows 'skipped' (no local settings found).""" + with patch.object(Path, "expanduser", return_value=tmp_path / "nonexistent"): + _sync_settings( + "dropkit.test", + verbose=False, + remote_home="/home/testuser", + selected=set(), + ) + + print_calls = [str(c) for c in mock_console.print.call_args_list] + assert any("skipped" in c for c in print_calls) + + @patch("dropkit.main.subprocess.run") + def test_settings_json_is_sanitized(self, mock_run, tmp_path): + """Rsync source for settings.json is a temp file with only allowlisted keys.""" + settings = tmp_path / ".claude" / "settings.json" + settings.parent.mkdir(parents=True) + settings.write_text( + json.dumps( + { + "model": "opus", + "permissions": {"/local/path": {"allow": ["Read"]}}, + "env": {"SECRET": "leaked"}, + } + ) + ) + + mock_run.return_value = MagicMock(returncode=0) + + def expand_side_effect(self): + local_path = str(self) + if "settings.json" in local_path: + return settings + return tmp_path / "nonexistent" + + with patch.object(Path, "expanduser", expand_side_effect): + _sync_settings("dropkit.test", verbose=False, remote_home="/home/testuser") + + # Find the rsync call for settings.json + rsync_calls = [ + c + for c in mock_run.call_args_list + if "rsync" in str(c[0][0]) and "settings.json" in str(c[0][0]) + ] + assert len(rsync_calls) == 1 + rsync_cmd = rsync_calls[0][0][0] + # Source should 
NOT be the original file path + source_arg = rsync_cmd[-2] # second-to-last arg is source + assert str(settings) not in source_arg + # The temp file content should only contain allowlisted keys + # (temp file is cleaned up, but we can check via the rsync source path) + assert source_arg != str(settings) + + @patch("dropkit.main.console") + @patch("dropkit.main.subprocess.run") + def test_settings_sanitize_failure_skips(self, mock_run, mock_console, tmp_path): + """Invalid JSON in settings.json is skipped entirely.""" + settings = tmp_path / ".claude" / "settings.json" + claude_md = tmp_path / ".claude" / "CLAUDE.md" + settings.parent.mkdir(parents=True) + settings.write_text("NOT VALID JSON {{{") + claude_md.write_text("# test") + + mock_run.return_value = MagicMock(returncode=0) + + def expand_side_effect(self): + local_path = str(self) + if "settings.json" in local_path: + return settings + if "CLAUDE.md" in local_path: + return claude_md + return tmp_path / "nonexistent" + + with patch.object(Path, "expanduser", expand_side_effect): + _sync_settings("dropkit.test", verbose=True, remote_home="/home/testuser") + + # settings.json should be skipped, but CLAUDE.md should still sync + rsync_calls = [c for c in mock_run.call_args_list if "rsync" in str(c[0][0])] + for call in rsync_calls: + cmd = call[0][0] + # No rsync call should reference settings.json + assert not any("settings.json" in str(arg) for arg in cmd) + + @patch("dropkit.main.subprocess.run") + def test_temp_file_cleaned_up_on_success(self, mock_run, tmp_path): + """Temp file is deleted after successful rsync.""" + settings = tmp_path / ".claude" / "settings.json" + settings.parent.mkdir(parents=True) + settings.write_text(json.dumps({"model": "opus"})) + + created_temps: list[str] = [] + original_mkstemp = __import__("tempfile").mkstemp + + def tracking_mkstemp(**kwargs): + fd, path = original_mkstemp(**kwargs) + created_temps.append(path) + return fd, path + + mock_run.return_value = 
MagicMock(returncode=0) + + def expand_side_effect(self): + local_path = str(self) + if "settings.json" in local_path: + return settings + return tmp_path / "nonexistent" + + with ( + patch.object(Path, "expanduser", expand_side_effect), + patch("dropkit.main.tempfile.mkstemp", side_effect=tracking_mkstemp), + ): + _sync_settings("dropkit.test", verbose=False, remote_home="/home/testuser") + + assert len(created_temps) == 1 + assert not Path(created_temps[0]).exists() + + @patch("dropkit.main.subprocess.run") + def test_temp_file_cleaned_up_on_rsync_failure(self, mock_run, tmp_path): + """Temp file is deleted even when rsync fails.""" + settings = tmp_path / ".claude" / "settings.json" + settings.parent.mkdir(parents=True) + settings.write_text(json.dumps({"model": "opus"})) + + created_temps: list[str] = [] + original_mkstemp = __import__("tempfile").mkstemp + + def tracking_mkstemp(**kwargs): + fd, path = original_mkstemp(**kwargs) + created_temps.append(path) + return fd, path + + # mkdir succeeds, rsync fails + mock_run.side_effect = [ + MagicMock(returncode=0, stderr=b""), # mkdir + MagicMock(returncode=1, stderr=b"rsync error"), # rsync + ] + + def expand_side_effect(self): + local_path = str(self) + if "settings.json" in local_path: + return settings + return tmp_path / "nonexistent" + + with ( + patch.object(Path, "expanduser", expand_side_effect), + patch("dropkit.main.tempfile.mkstemp", side_effect=tracking_mkstemp), + ): + _sync_settings("dropkit.test", verbose=False, remote_home="/home/testuser") + + assert len(created_temps) == 1 + assert not Path(created_temps[0]).exists() + + +class TestOpenAuthSession: + """Tests for _open_auth_session function.""" + + @patch("dropkit.main.subprocess.run") + def test_runs_claude_login_via_ssh(self, mock_run): + """Runs claude /login on the droplet via SSH.""" + mock_run.return_value = MagicMock(returncode=0) + result = _open_auth_session("dropkit.test", verbose=False) + + assert result is True + 
mock_run.assert_called_once() + cmd = mock_run.call_args[0][0] + assert cmd[0] == "ssh" + assert "dropkit.test" in cmd + # Should run claude /login, not open an interactive shell + cmd_str = " ".join(cmd) + assert "claude /login" in cmd_str + + @patch("dropkit.main.subprocess.run") + def test_uses_ephemeral_temp_directory(self, mock_run): + """Creates a temp directory and cleans it up after auth.""" + mock_run.return_value = MagicMock(returncode=0) + _open_auth_session("dropkit.test", verbose=False) + + cmd = mock_run.call_args[0][0] + cmd_str = " ".join(cmd) + assert "mktemp -d" in cmd_str + assert 'rm -rf "$dir"' in cmd_str + + @patch("dropkit.main.subprocess.run") + def test_uses_bash_login_shell(self, mock_run): + """Uses bash -lc to bypass zsh/p10k wizard.""" + mock_run.return_value = MagicMock(returncode=0) + _open_auth_session("dropkit.test", verbose=False) + + cmd = mock_run.call_args[0][0] + cmd_str = " ".join(cmd) + assert "bash -lc" in cmd_str + + @patch("dropkit.main.subprocess.run") + def test_allocates_pseudo_tty(self, mock_run): + """SSH command includes -t for pseudo-TTY so claude /login can display its URL.""" + mock_run.return_value = MagicMock(returncode=0) + _open_auth_session("dropkit.test", verbose=False) + + cmd = mock_run.call_args[0][0] + assert "-t" in cmd + + @patch("dropkit.main.subprocess.run") + def test_no_port_forwarding(self, mock_run): + """No port forwarding args in SSH command.""" + mock_run.return_value = MagicMock(returncode=0) + _open_auth_session("dropkit.test", verbose=False) + + cmd = mock_run.call_args[0][0] + assert "-L" not in cmd + + @patch("dropkit.main.subprocess.run") + def test_does_not_capture_output(self, mock_run): + """Interactive session must not capture output.""" + mock_run.return_value = MagicMock(returncode=0) + _open_auth_session("dropkit.test", verbose=False) + kwargs = mock_run.call_args[1] + assert "capture_output" not in kwargs or kwargs["capture_output"] is False + + @patch("dropkit.main.console") + 
@patch("dropkit.main.subprocess.run") + def test_verbose_shows_command(self, mock_run, mock_console): + """Verbose mode prints the SSH command.""" + mock_run.return_value = MagicMock(returncode=0) + _open_auth_session("dropkit.test", verbose=True) + print_calls = [str(c) for c in mock_console.print.call_args_list] + assert any("ssh" in c and "claude /login" in c for c in print_calls) + + @patch("dropkit.main.subprocess.run") + def test_os_error_raises_exit(self, mock_run): + """OSError (e.g., ssh not found) raises typer.Exit.""" + mock_run.side_effect = FileNotFoundError("ssh not found") + with pytest.raises(typer.Exit): + _open_auth_session("dropkit.test", verbose=False) + + @patch("dropkit.main.console") + @patch("dropkit.main.subprocess.run") + def test_ssh_failure_exit_255(self, mock_run, mock_console): + """Exit code 255 (SSH connection failure) returns False with clear message.""" + mock_run.return_value = MagicMock(returncode=255) + result = _open_auth_session("dropkit.test", verbose=False) + assert result is False + print_calls = [str(c) for c in mock_console.print.call_args_list] + assert any("SSH connection failed" in c for c in print_calls) + + @patch("dropkit.main.subprocess.run") + def test_returns_true_on_success(self, mock_run): + """Returns True when session exits cleanly.""" + mock_run.return_value = MagicMock(returncode=0) + assert _open_auth_session("dropkit.test", verbose=False) is True + + @patch("dropkit.main.subprocess.run") + def test_nonzero_exit_returns_true(self, mock_run): + """Non-zero non-255 exit is normal (e.g. 
user cancelled), returns True.""" + mock_run.return_value = MagicMock(returncode=1) + assert _open_auth_session("dropkit.test", verbose=False) is True + + @patch("dropkit.main.console") + @patch("dropkit.main.subprocess.run") + def test_keyboard_interrupt_warns_about_auth(self, mock_run, mock_console): + """Ctrl-C warns that auth may be incomplete and exits with code 130.""" + mock_run.side_effect = KeyboardInterrupt() + with pytest.raises(typer.Exit) as exc_info: + _open_auth_session("dropkit.test", verbose=False) + assert exc_info.value.exit_code == 130 + print_calls = [str(c) for c in mock_console.print.call_args_list] + assert any("re-run" in c.lower() for c in print_calls) + + +class TestSetupClaude: + """Tests for the setup_claude command orchestrator.""" + + @patch("dropkit.main._open_auth_session") + @patch("dropkit.main._auth_github") + @patch("dropkit.main._install_claude_code", return_value=False) + @patch("dropkit.main.find_user_droplet", return_value=({"name": "test"}, "testuser")) + @patch("dropkit.main.load_config_and_api", return_value=(MagicMock(), MagicMock())) + def test_exits_on_install_failure( + self, mock_config, mock_find, mock_install, mock_auth, mock_session + ): + """Aborts with exit code 1 when install fails.""" + with pytest.raises(typer.Exit) as exc_info: + setup_claude(droplet_name="test", sync_all=False, verbose=False) + assert exc_info.value.exit_code == 1 + mock_auth.assert_not_called() + mock_session.assert_not_called() + + @patch("dropkit.main._open_auth_session", return_value=True) + @patch("dropkit.main._sync_settings") + @patch("dropkit.main._auth_github") + @patch("dropkit.main._install_claude_code", return_value=True) + @patch("dropkit.main.find_user_droplet", return_value=({"name": "test"}, "testuser")) + @patch("dropkit.main.load_config_and_api", return_value=(MagicMock(), MagicMock())) + def test_sync_all_runs_everything( + self, mock_config, mock_find, mock_install, mock_auth, mock_sync, mock_session + ): + """--sync-all 
runs _auth_github and _sync_settings with selected=None.""" + setup_claude(droplet_name="test", sync_all=True, verbose=False) + mock_auth.assert_called_once() + mock_sync.assert_called_once_with( + "dropkit.test", False, remote_home="/home/testuser", selected=None + ) + mock_session.assert_called_once() + + @patch("dropkit.main._open_auth_session", return_value=True) + @patch("dropkit.main._sync_settings") + @patch("dropkit.main._auth_github") + @patch("dropkit.main._prompt_sync_selection", return_value=set()) + @patch("dropkit.main._discover_sync_choices", return_value=[]) + @patch("dropkit.main._install_claude_code", return_value=True) + @patch("dropkit.main.find_user_droplet", return_value=({"name": "test"}, "testuser")) + @patch("dropkit.main.load_config_and_api", return_value=(MagicMock(), MagicMock())) + def test_empty_selection_skips_sync( + self, + mock_config, + mock_find, + mock_install, + mock_discover, + mock_prompt, + mock_auth, + mock_sync, + mock_session, + ): + """Empty selection skips both auth and sync.""" + setup_claude(droplet_name="test", sync_all=False, verbose=False) + mock_auth.assert_not_called() + mock_sync.assert_not_called() + mock_session.assert_called_once() + + @patch("dropkit.main._open_auth_session", return_value=True) + @patch("dropkit.main._sync_settings") + @patch("dropkit.main._auth_github") + @patch("dropkit.main._prompt_sync_selection", return_value={"github_token", "claude_md"}) + @patch( + "dropkit.main._discover_sync_choices", + return_value=[ + SyncChoice("Global CLAUDE.md", "claude_md"), + SyncChoice("GitHub token", "github_token"), + ], + ) + @patch("dropkit.main._install_claude_code", return_value=True) + @patch("dropkit.main.find_user_droplet", return_value=({"name": "test"}, "testuser")) + @patch("dropkit.main.load_config_and_api", return_value=(MagicMock(), MagicMock())) + def test_github_token_gated_by_selection( + self, + mock_config, + mock_find, + mock_install, + mock_discover, + mock_prompt, + mock_auth, + 
mock_sync,
+        mock_session,
+    ):
+        """GitHub auth runs only when github_token is selected."""
+        setup_claude(droplet_name="test", sync_all=False, verbose=False)
+        mock_auth.assert_called_once()
+        # _sync_settings should be called with claude_md only (not github_token)
+        mock_sync.assert_called_once()
+        call_kwargs = mock_sync.call_args
+        selected = call_kwargs[1]["selected"] if "selected" in call_kwargs[1] else call_kwargs[0][3]
+        assert "github_token" not in selected
+        assert "claude_md" in selected
+
+    @patch("dropkit.main._open_auth_session", return_value=True)
+    @patch("dropkit.main._sync_settings")
+    @patch("dropkit.main._auth_github")
+    @patch("dropkit.main._prompt_sync_selection", return_value={"claude_md"})
+    @patch(
+        "dropkit.main._discover_sync_choices",
+        return_value=[
+            SyncChoice("Global CLAUDE.md", "claude_md"),
+        ],
+    )
+    @patch("dropkit.main._install_claude_code", return_value=True)
+    @patch("dropkit.main.find_user_droplet", return_value=({"name": "test"}, "testuser"))
+    @patch("dropkit.main.load_config_and_api", return_value=(MagicMock(), MagicMock()))
+    def test_default_behavior_calls_prompt(
+        self,
+        mock_config,
+        mock_find,
+        mock_install,
+        mock_discover,
+        mock_prompt,
+        mock_auth,
+        mock_sync,
+        mock_session,
+    ):
+        """Default (no --sync-all) calls discover + prompt."""
+        setup_claude(droplet_name="test", sync_all=False, verbose=False)
+        mock_discover.assert_called_once()
+        mock_prompt.assert_called_once()
+
+    @patch("dropkit.main._open_auth_session", return_value=True)
+    @patch("dropkit.main._auth_github")
+    @patch("dropkit.main._install_claude_code", return_value=True)
+    @patch("dropkit.main.find_user_droplet", return_value=({"name": "test"}, "testuser"))
+    @patch("dropkit.main.load_config_and_api", return_value=(MagicMock(), MagicMock()))
+    def test_always_opens_auth_session(
+        self, mock_config, mock_find, mock_install, mock_auth, mock_session
+    ):
+        """Opens the auth session after GitHub auth completes without raising."""
+        mock_auth.side_effect = 
None # auth does not raise + setup_claude(droplet_name="test", sync_all=True, verbose=False) + mock_session.assert_called_once() + + @patch("dropkit.main._open_auth_session") + @patch("dropkit.main._install_claude_code") + @patch("dropkit.main.find_user_droplet", return_value=(None, "testuser")) + @patch("dropkit.main.load_config_and_api", return_value=(MagicMock(), MagicMock())) + def test_exits_on_droplet_not_found(self, mock_config, mock_find, mock_install, mock_session): + """Exits with code 1 when droplet not found.""" + with pytest.raises(typer.Exit) as exc_info: + setup_claude(droplet_name="nonexistent", sync_all=False, verbose=False) + assert exc_info.value.exit_code == 1 + mock_install.assert_not_called() + + @patch("dropkit.main._open_auth_session", return_value=True) + @patch("dropkit.main._sync_settings") + @patch("dropkit.main._auth_github") + @patch("dropkit.main._prompt_sync_selection", return_value={"github_token"}) + @patch( + "dropkit.main._discover_sync_choices", + return_value=[SyncChoice("GitHub token", "github_token")], + ) + @patch("dropkit.main._install_claude_code", return_value=True) + @patch("dropkit.main.find_user_droplet", return_value=({"name": "test"}, "testuser")) + @patch("dropkit.main.load_config_and_api", return_value=(MagicMock(), MagicMock())) + def test_github_token_only_does_not_sync_everything( + self, + mock_config, + mock_find, + mock_install, + mock_discover, + mock_prompt, + mock_auth, + mock_sync, + mock_session, + ): + """Selecting only github_token must NOT call _sync_settings with selected=None.""" + setup_claude(droplet_name="test", sync_all=False, verbose=False) + mock_auth.assert_called_once() + # _sync_settings should NOT be called (empty set = nothing to sync) + mock_sync.assert_not_called() + + @patch("dropkit.main._open_auth_session", return_value=False) + @patch("dropkit.main._auth_github") + @patch("dropkit.main._install_claude_code", return_value=True) + @patch("dropkit.main.find_user_droplet", 
return_value=({"name": "test"}, "testuser")) + @patch("dropkit.main.load_config_and_api", return_value=(MagicMock(), MagicMock())) + def test_exits_nonzero_on_ssh_connection_failure( + self, mock_config, mock_find, mock_install, mock_auth, mock_session + ): + """Exits with code 1 when SSH connection fails (exit 255).""" + with pytest.raises(typer.Exit) as exc_info: + setup_claude(droplet_name="test", sync_all=True, verbose=False) + assert exc_info.value.exit_code == 1