112 changes: 78 additions & 34 deletions monai/apps/nnunet/nnunetv2_runner.py
@@ -14,6 +14,7 @@

import glob
import os
import shlex
import subprocess
from typing import Any

@@ -486,16 +487,16 @@ def plan_and_process(
if not no_pp:
self.preprocess(c, n_proc, overwrite_plans_name, verbose)
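
For context, a hedged usage sketch of the runner entry point that reaches this method; the config path and the all-defaults call are assumptions for illustration, not taken from this diff:

```python
from monai.apps.nnunet import nnUNetV2Runner

# Hypothetical: data_src.yaml describes the dataset; default arguments are assumed.
runner = nnUNetV2Runner(input_config="./data_src.yaml")
runner.plan_and_process()  # experiment planning, then preprocessing unless no_pp is set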

def train_single_model(self, config: Any, fold: int, gpu_id: tuple | list | int = 0, **kwargs: Any) -> None:
def train_single_model(self, config: Any, fold: int, gpu_id: tuple | list | int | str = 0, **kwargs: Any) -> None:
"""
Run the training with one specified configuration on the selected GPU device(s).
Note: this will override the environment variable `CUDA_VISIBLE_DEVICES`.
Note: if CUDA_VISIBLE_DEVICES is already set and gpu_id resolves to 0, the existing value is preserved;
otherwise it is set to gpu_id.

Args:
config: configuration that should be trained. Examples: "2d", "3d_fullres", "3d_lowres".
fold: fold of the 5-fold cross-validation. Should be an int between 0 and 4.
gpu_id: an integer to select the device to use, or a tuple/list of GPU device indices used for multi-GPU
training (e.g., (0,1)). Default: 0.
gpu_id: an int, MIG UUID (str), or tuple/list of GPU indices for multi-GPU training (e.g., (0,1)). Default: 0.
kwargs: this optional parameter allows you to specify additional arguments in
``nnunetv2.run.run_training.run_training_entry``.

@@ -525,35 +526,71 @@ def train_single_model(self, config: Any, fold: int, gpu_id: tuple | list | int
kwargs.pop("npz")
logger.warning("please specify the `export_validation_probabilities` in the __init__ of `nnUNetV2Runner`.")

cmd = self.train_single_model_command(config, fold, gpu_id, kwargs)
run_cmd(cmd, shell=True)
cmd, env = self.train_single_model_command(config, fold, gpu_id, kwargs)
run_cmd(cmd, env=env)
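
Building an argv list with an explicit `env` (rather than prefixing `CUDA_VISIBLE_DEVICES=...` onto a `shell=True` string) avoids shell parsing entirely, which matters once `gpu_id` can be a MIG UUID. A minimal sketch of the equivalent call, assuming `run_cmd` forwards its arguments to `subprocess.run`:

```python
import os
import subprocess

# Argv list + explicit environment: no shell, no quoting pitfalls.
cmd = ["nnUNetv2_train", "101", "3d_fullres", "0", "-tr", "nnUNetTrainer", "-num_gpus", "1"]
env = dict(os.environ, CUDA_VISIBLE_DEVICES="MIG-<uuid>")  # placeholder MIG UUID
subprocess.run(cmd, env=env, check=True)
```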

def train_single_model_command(
    self, config: str, fold: int, gpu_id: int | str | tuple | list, kwargs: dict[str, Any]
) -> tuple[list[str], dict[str, str]]:
    """
    Build the training command (an argv list) and the subprocess environment for a single nnU-Net model.

    Args:
        config: Configuration name (e.g., "3d_fullres").
        fold: Cross-validation fold index (0-4).
        gpu_id: Device selector: int, str (MIG UUID), or tuple/list for multi-GPU.
        kwargs: Additional CLI arguments forwarded to nnUNetv2_train.

    Returns:
        Tuple of (cmd, env) where cmd is a list[str] of argv entries and env is a dict[str, str]
        passed to the subprocess.

    Raises:
        ValueError: If gpu_id is an empty tuple or list.
    """
    env = os.environ.copy()
    device_setting: str | None = None
    num_gpus = 1
    if isinstance(gpu_id, str):
        device_setting = gpu_id
    elif isinstance(gpu_id, (tuple, list)):
        if len(gpu_id) == 0:
            raise ValueError("gpu_id tuple/list cannot be empty")
        if len(gpu_id) > 1:
            device_setting = ",".join(str(x) for x in gpu_id)
            num_gpus = len(gpu_id)
        else:
            device_setting = str(gpu_id[0])
    else:
        device_setting = str(gpu_id)

    env_cuda = env.get("CUDA_VISIBLE_DEVICES")
    if env_cuda is not None and device_setting == "0":
        logger.info(f"Using existing environment variable CUDA_VISIBLE_DEVICES='{env_cuda}'")
        device_setting = None
    elif device_setting is not None:
        env["CUDA_VISIBLE_DEVICES"] = device_setting

    cmd = [
        "nnUNetv2_train",
        f"{self.dataset_name_or_id}",
        f"{config}",
        f"{fold}",
        "-tr",
        f"{self.trainer_class_name}",
        "-num_gpus",
        f"{num_gpus}",
    ]
    if self.export_validation_probabilities:
        cmd.append("--npz")
    for _key, _value in kwargs.items():
        if _key in ("p", "pretrained_weights"):
            cmd.extend([f"-{_key}", f"{_value}"])
        else:
            cmd.extend([f"--{_key}", f"{_value}"])
    return cmd, env
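
To illustrate, a hedged sketch of what the builder returns for each `gpu_id` form; the runner instance, dataset id 101, and trainer name are hypothetical, and `export_validation_probabilities=False` is assumed:

```python
# Hypothetical runner configured with dataset 101 and trainer "nnUNetTrainer".
cmd, env = runner.train_single_model_command("3d_fullres", 0, (0, 1), {})
# cmd == ["nnUNetv2_train", "101", "3d_fullres", "0", "-tr", "nnUNetTrainer", "-num_gpus", "2"]
# env["CUDA_VISIBLE_DEVICES"] == "0,1"

cmd, env = runner.train_single_model_command("3d_fullres", 0, "MIG-<uuid>", {"p": "nnUNetPlans"})
# env["CUDA_VISIBLE_DEVICES"] == "MIG-<uuid>"; cmd ends with ["-p", "nnUNetPlans"]

cmd, env = runner.train_single_model_command("2d", 1, 0, {})
# If CUDA_VISIBLE_DEVICES was already set by the launcher, the default gpu_id of 0
# preserves it; otherwise env["CUDA_VISIBLE_DEVICES"] == "0".
```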

def train(
self,
@@ -637,8 +674,8 @@ def train_parallel_cmd(
if _config in ensure_tuple(configs):
for _i in range(self.num_folds):
the_device = gpu_id_for_all[_index % n_devices] # type: ignore
cmd = self.train_single_model_command(_config, _i, the_device, kwargs)
all_cmds[-1][the_device].append(cmd)
cmd, env = self.train_single_model_command(_config, _i, the_device, kwargs)
all_cmds[-1][the_device].append((cmd, env))
_index += 1
return all_cmds
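
The returned value is a list of stages, each mapping a device selector to the `(cmd, env)` pairs queued on it. A hedged illustration of the shape, assuming two devices and that the cascade configuration lands in a later stage because it depends on `3d_lowres` results:

```python
# Hypothetical shape of train_parallel_cmd output (values illustrative only).
env0 = {"CUDA_VISIBLE_DEVICES": "0"}
env1 = {"CUDA_VISIBLE_DEVICES": "1"}
all_cmds = [
    {  # stage 1: folds of the non-cascade configs, round-robined over devices
        0: [(["nnUNetv2_train", "101", "2d", "0", "-tr", "nnUNetTrainer", "-num_gpus", "1"], env0)],
        1: [(["nnUNetv2_train", "101", "2d", "1", "-tr", "nnUNetTrainer", "-num_gpus", "1"], env1)],
    },
    {  # stage 2: cascade folds, started only after stage 1 completes
        0: [(["nnUNetv2_train", "101", "3d_cascade_fullres", "0", "-tr", "nnUNetTrainer", "-num_gpus", "1"], env0)]
    },
]
```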

@@ -666,19 +703,21 @@ def train_parallel(
for gpu_id, gpu_cmd in cmds.items():
if not gpu_cmd:
continue
cmds_for_log = [shlex.join(cmd) for cmd, _ in gpu_cmd]
logger.info(
f"training - stage {s + 1}:\n"
f"for gpu {gpu_id}, commands: {gpu_cmd}\n"
f"for gpu {gpu_id}, commands: {cmds_for_log}\n"
f"log '.txt' inside '{os.path.join(self.nnunet_results, self.dataset_name)}'"
)
for stage in all_cmds:
processes = []
for device_id in stage:
if not stage[device_id]:
continue
cmd_str = "; ".join(stage[device_id])
cmd_str = "; ".join(shlex.join(cmd) for cmd, _ in stage[device_id])
env = stage[device_id][0][1]
logger.info(f"Current running command on GPU device {device_id}:\n{cmd_str}\n")
processes.append(subprocess.Popen(cmd_str, shell=True, stdout=subprocess.DEVNULL))
processes.append(subprocess.Popen(cmd_str, shell=True, env=env, stdout=subprocess.DEVNULL))
# finish this stage first
for p in processes:
p.wait()
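
The scheduling pattern is: chain one device's commands with "; " so they run serially in a single shell, launch one such shell per device in parallel, and use `wait()` as a barrier between stages. A self-contained sketch with placeholder `echo` commands standing in for training runs:

```python
import shlex
import subprocess

# Placeholder jobs standing in for nnUNetv2_train invocations.
stage = {
    0: [(["echo", "device 0, fold 0"], None), (["echo", "device 0, fold 1"], None)],
    1: [(["echo", "device 1, fold 0"], None)],
}
processes = []
for device_id, jobs in stage.items():
    script = "; ".join(shlex.join(cmd) for cmd, _ in jobs)  # serial per device
    processes.append(subprocess.Popen(script, shell=True))  # parallel across devices
for p in processes:
    p.wait()  # barrier: the next stage starts only when every device is done
```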
@@ -779,7 +818,7 @@ def predict(
part_id: int = 0,
num_processes_preprocessing: int = -1,
num_processes_segmentation_export: int = -1,
gpu_id: int = 0,
gpu_id: int | str = 0,
) -> None:
"""
Use this to run inference with nnU-Net. This function is used when you want to manually specify a folder containing
@@ -813,9 +852,14 @@
num_processes_preprocessing: Number of processes used for preprocessing. More is not always better. Beware of
out-of-RAM issues.
num_processes_segmentation_export: Number of processes used for segmentation export.
More is not always better. Beware of out-of-RAM issues.
gpu_id: which GPU to use for prediction.
gpu_id: GPU device index (int) or MIG UUID (str) for prediction.
If CUDA_VISIBLE_DEVICES is already set and gpu_id is 0, the existing
environment variable is preserved.
"""
os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}"
if "CUDA_VISIBLE_DEVICES" in os.environ and (gpu_id == 0 or gpu_id == "0"):
logger.info(f"Predict: Using existing CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']}")
else:
os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}"

from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
