diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 64db4646f6..0a0ccf868b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -12,6 +12,8 @@ import torch from torch.nn.parameter import Parameter +from torch.distributed import DeviceMesh +from torch.distributed.tensor import DTensor import transformer_engine_torch as tex from transformer_engine.common.recipe import ( @@ -44,6 +46,7 @@ set_all_rng_states, CudaRNGStatesTracker, graph_safe_rng_available, + _convert_param_to_dtensor_param, ) from transformer_engine.pytorch.jit import no_torch_dynamo from transformer_engine.pytorch.graph import is_graph_capturing @@ -298,6 +301,11 @@ class DotProductAttention(TransformerEngineBaseModule): ``ProcessGroup`` is for :attr:`cp_comm_type` of ``"p2p"``, ``"all_gather"``, and ``"a2a"``. ``List[ProcessGroup]`` is for :attr:`cp_comm_type` of ``"a2a+p2p"``, where :attr:`cp_group[0]` and :attr:`cp_group[1]` are for ``"a2a"`` and ``"p2p"`` communications respectively. + tp_mesh : Optional[DeviceMesh], default = None + A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"]. + Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP. + weight_mesh : Optional[DeviceMesh] + Not used for DotProductAttention as there are no quantized weights. cp_global_ranks : list of global rank IDs, default = None global rank IDs of GPUs that are in ``cp_group``. 
cp_stream : CUDA stream, default = None @@ -343,6 +351,8 @@ def __init__( softmax_scale: Optional[float] = None, softmax_type: str = "vanilla", return_max_logit: Optional[bool] = False, + tp_mesh: Optional[DeviceMesh] = None, + weight_mesh: Optional[DeviceMesh] = None, ) -> None: super().__init__() @@ -477,6 +487,10 @@ def __init__( return_max_logit=self.return_max_logit, ) + if tp_mesh is not None or weight_mesh is not None: + # Apply DeviceMesh and DTensor-related modifications. + self.set_device_mesh(tp_mesh=tp_mesh, weight_mesh=weight_mesh) + def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument """ Temporarily remove core_attention._extra_state as a missing key @@ -533,6 +547,54 @@ def custom_forward(*input_args, **input_kwargs): return hidden_states + def set_device_mesh( + self, + tp_mesh: Optional[DeviceMesh] = None, + weight_mesh: Optional[DeviceMesh] = None, + ) -> None: + """ + Set DeviceMesh(s) used for sharding weights and convert main weights into DTensor + depending on the TransformerEngine class to support FSDP-TP sharding with FSDP2. + + TransformerEngine manages tensor parallel mechanics, while DTensor offers seamless + integration with Torch DCP checkpointing. This method should only be invoked when + using DTensor parameters, e.g. when using FSDP2 or DCP. + + When FSDP2 fully_shard() encounters any DTensor Shard(s), it will automatically + convert them into FSDP-TP strided or non-strided shards depending on the current + sharding dimension and factor of the DTensor. When the sharding dimension of FSDP + matches that of TP, FSDP uses a _StridedShard placement type instead of Shard. + This experimental FSDP-TP logic presides in this FSDP2 initialization function: + ``torch.distributed.fsdp._fully_shard._fsdp_param._init_sharded_param`` + + Parameters + ---------- + tp_mesh : Optional[DeviceMesh] + A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"]. 
+ Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP. + weight_mesh : Optional[DeviceMesh] + Not used for DotProductAttention as there are no quantized weights. + """ + if tp_mesh is not None: + # Validate TP DeviceMesh / Group. Must be consistent with tp_size. + assert ( + tp_mesh.ndim == 1 and self.tp_size == tp_mesh.size() + ), ( + f"TransformerEngine {self.__class__.__name__} TP init size ({self.tp_size}) " + f"does not match the size of the " + f"provided TP DeviceMesh ({tp_mesh.size()})." + ) + # Set the tensor parallel group from the mesh. + self.set_tensor_parallel_group(tp_mesh.get_group()) + + # Construct TP-sharded DTensors. + if self.softmax_type == "learnable": + from torch.distributed.tensor.placement_types import Shard + + self.softmax_offset = _convert_param_to_dtensor_param( + self.softmax_offset, tp_mesh, placements=(Shard(dim=0),) + ) + def set_context_parallel_group( self, cp_group: Union[dist_group_type, List[dist_group_type], None], @@ -802,6 +864,17 @@ def set_meta_tensor(self, fwd: bool, recipe: Union[Recipe, List[Recipe]]) -> Non for recipe_state in recipe_states: self.quantizers[fp8_meta_tensor_key].extend(recipe_state.make_quantizers()) + def _get_softmax_offset(self) -> Optional[torch.Tensor]: + """Get the softmax offset.""" + softmax_offset = ( + self.softmax_offset.reshape(1, -1, 1, 1).to(torch.float32) + if self.softmax_offset is not None + else None + ) + if isinstance(softmax_offset, DTensor): + softmax_offset = softmax_offset.to_local() + return softmax_offset + @no_torch_dynamo(recursive=False) def forward( self, @@ -1434,11 +1507,7 @@ def forward( ) # run attention - softmax_offset = ( - self.softmax_offset.reshape(1, -1, 1, 1).to(torch.float32) - if self.softmax_offset is not None - else None - ) + softmax_offset = self._get_softmax_offset() if use_flash_attention: if core_attention_bias_type == "alibi": diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py
b/transformer_engine/pytorch/attention/multi_head_attention.py index 01c4955d78..80c9517a9b 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -7,6 +7,7 @@ import collections from typing import Callable, List, Optional, Tuple, Union import torch +from torch.distributed import DeviceMesh from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor @@ -191,6 +192,19 @@ class MultiheadAttention(torch.nn.Module): ``set_tensor_parallel_group(tp_group)`` method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives. + tp_mesh : Optional[DeviceMesh], default = None + A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"]. + Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP. + weight_mesh : Optional[DeviceMesh], default = None + A 1-D DeviceMesh containing a weight-sharding mesh dimension. Only required + when using the FP8 Current (per-tensor) Scaling recipe on sharded DTensor + parameters and if the DTensor DeviceMesh includes dimensions that do not + shard weights, such as in the case of HSDP (DP-Replicate x DP-Shard). + For example: + - device_mesh["dp"] for FSDP. + - device_mesh["dp_cp"] if using CP ranks in FSDP. + - device_mesh["tp"] if using TP. + - device_mesh["dp_cp_tp"] if strided-sharding with FSDP-TP. 
Optimization parameters ----------------------- @@ -286,6 +300,8 @@ def __init__( seq_length: Optional[int] = None, micro_batch_size: Optional[int] = None, softmax_type: str = "vanilla", + tp_mesh: Optional[DeviceMesh] = None, + weight_mesh: Optional[DeviceMesh] = None, ) -> None: super().__init__() @@ -357,6 +373,13 @@ def __init__( self.q_norm, self.k_norm = self._create_qk_norm_modules( qk_norm_type, qk_norm_eps, device, seq_length, micro_batch_size ) + if tp_mesh is not None or weight_mesh is not None: + # Apply DeviceMesh and DTensor-related modifications. + # Only necessary for trainable weighted norms. + if hasattr(self.q_norm, "set_device_mesh"): + self.q_norm.set_device_mesh(tp_mesh=tp_mesh, weight_mesh=weight_mesh) + if hasattr(self.k_norm, "set_device_mesh"): + self.k_norm.set_device_mesh(tp_mesh=tp_mesh, weight_mesh=weight_mesh) qkv_parallel_mode = "column" if set_parallel_mode else None @@ -389,6 +412,8 @@ def __init__( normalization=normalization, ub_name="qkv", name=name + ".layernorm_linear_qkv" if name is not None else None, + tp_mesh=tp_mesh, + weight_mesh=weight_mesh, **common_gemm_kwargs, ) else: @@ -401,6 +426,8 @@ def __init__( parallel_mode=qkv_parallel_mode, parameters_split=parameters_split, name=name + ".linear_qkv" if name is not None else None, + tp_mesh=tp_mesh, + weight_mesh=weight_mesh, **common_gemm_kwargs, ) elif self.attention_type == "cross": @@ -423,6 +450,8 @@ def __init__( normalization=normalization, ub_name="qkv", name=name + ".layernorm_linear_q" if name is not None else None, + tp_mesh=tp_mesh, + weight_mesh=weight_mesh, **common_gemm_kwargs, ) else: @@ -433,6 +462,8 @@ def __init__( bias=bias, return_bias=False, parallel_mode=qkv_parallel_mode, + tp_mesh=tp_mesh, + weight_mesh=weight_mesh, **common_gemm_kwargs, ) self.key_value = Linear( @@ -444,6 +475,8 @@ def __init__( parallel_mode=qkv_parallel_mode, parameters_split=("key", "value") if not fuse_qkv_params else None, name=name + ".linear_kv" if name is not None else 
None, + tp_mesh=tp_mesh, + weight_mesh=weight_mesh, **common_gemm_kwargs, ) @@ -461,6 +494,8 @@ def __init__( layer_number=self.layer_number, attention_type=self.attention_type, softmax_type=self.softmax_type, + tp_mesh=tp_mesh, + weight_mesh=weight_mesh, ) # Linear @@ -475,6 +510,8 @@ def __init__( ub_overlap_ag=ub_overlap_ag, ub_name="proj", name=name + ".proj" if name is not None else None, + tp_mesh=tp_mesh, + weight_mesh=weight_mesh, **common_gemm_kwargs, ) @@ -562,6 +599,62 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N """ self.tp_group = tp_group + def set_device_mesh( + self, + tp_mesh: Optional[DeviceMesh] = None, + weight_mesh: Optional[DeviceMesh] = None, + ) -> None: + """ + Set DeviceMesh(s) used for sharding weights and convert main weights into DTensor + depending on the TransformerEngine class to support FSDP-TP sharding with FSDP2. + + TransformerEngine manages tensor parallel mechanics, while DTensor offers seamless + integration with Torch DCP checkpointing. This method should only be invoked when + using DTensor parameters, e.g. when using FSDP2 or DCP. + + When FSDP2 fully_shard() encounters any DTensor Shard(s), it will automatically + convert them into FSDP-TP strided or non-strided shards depending on the current + sharding dimension and factor of the DTensor. When the sharding dimension of FSDP + matches that of TP, FSDP uses a _StridedShard placement type instead of Shard. + This experimental FSDP-TP logic presides in this FSDP2 initialization function: + ``torch.distributed.fsdp._fully_shard._fsdp_param._init_sharded_param`` + + Parameters + ---------- + tp_mesh : Optional[DeviceMesh] + A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"]. + Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP. + weight_mesh : Optional[DeviceMesh] + A 1-D DeviceMesh containing a weight-sharding mesh dimension. 
Only required + when using the FP8 Current (per-tensor) Scaling recipe on sharded DTensor + parameters and if the DTensor DeviceMesh includes dimensions that do not + shard weights, such as in the case of HSDP (DP-Replicate x DP-Shard). + For example: + - device_mesh["dp"] for FSDP. + - device_mesh["dp_cp"] if using CP ranks in FSDP. + - device_mesh["tp"] if using TP. + - device_mesh["dp_cp_tp"] if strided-sharding with FSDP-TP. + """ + if tp_mesh is not None: + # Validate TP DeviceMesh / Group. Must be consistent with tp_size. + assert ( + tp_mesh.ndim == 1 and self.tp_size == tp_mesh.size() + ), ( + f"TransformerEngine {self.__class__.__name__} TP init size ({self.tp_size}) " + f"does not match the size of the " + f"provided TP DeviceMesh ({tp_mesh.size()})." + ) + # Set the tensor parallel group from the mesh. + self.set_tensor_parallel_group(tp_mesh.get_group()) + + if tp_mesh is not None or weight_mesh is not None: + # Iterate through child sub-modules without deep recursion. + # Automatically detects TransformerEngine TP modules and + # the capability to call this method at any level.
+ for name, child in self.named_children(): + if hasattr(child, "set_device_mesh"): + child.set_device_mesh(tp_mesh, weight_mesh) + def set_context_parallel_group( self, cp_group: Union[dist_group_type, List[dist_group_type], None], diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index f269e21b8c..e0def07541 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -1934,6 +1934,40 @@ def _get_module_fsdp_state(module): return fsdp_state +def _convert_param_to_dtensor_param( + param: torch.nn.Parameter, + device_mesh: torch.distributed.DeviceMesh, + placements: Tuple[torch.distributed.tensor.placement_types.Placement], + shape: Optional[torch.Size] = None, + stride: Optional[Tuple[int]] = None, +): + """Convert the parameter into a DTensor.""" + from torch.distributed.tensor import DTensor + + # If the parameter is already a DTensor, extract local Tensor. + # We overwrite the original DTensor's distributed configuration. + param_tensor = param + if isinstance(param, DTensor): + param_tensor = param.to_local() + # Convert the parameter to a DTensor. + new_param = torch.nn.Parameter( + DTensor.from_local( + param_tensor, + device_mesh, + placements=placements, + shape=shape, + stride=stride, + ) + ) + # Inherit attributes of the original Parameter. + # For example, "param_init_meta" or "tensor_model_parallel". + for key, val in param.__dict__.items(): + if not hasattr(new_param, key): + # Set the original attribute. 
+ setattr(new_param, key, val) + return new_param + + def _fsdp_scatter_tensors( fsdp_group: dist_group_type, *tensors: torch.Tensor, diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 09b12afa21..b7a0e9b8ec 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -35,6 +35,7 @@ is_fp8_activation_recompute_enabled, in_fp8_activation_recompute_phase, _fsdp_gather_tensors, + _convert_param_to_dtensor_param, ) from ..constants import dist_group_type from ..cpp_extensions.gemm import _NUM_MAX_UB_STREAMS @@ -618,7 +619,7 @@ def __init__(self, name: Optional[str] = None) -> None: self.fp8_meta["fp8_checkpoint"] = False self.fp8_meta["fp8_group"] = None self.fp8_meta_tensors_initialized = False - self.quantizers = {"scaling_fwd": {}, "scaling_bwd": {}} + self.quantizers = {"scaling_fwd": [], "scaling_bwd": []} self.tp_group = None self.tp_size = 1 self.sequence_parallel = False @@ -1041,6 +1042,11 @@ def prepare_forward( ) self.fast_setattr("forwarded_at_least_once", True) + # If the input is a DTensor, localize it. DTensor.to_local() is differentiable. + # TransformerEngine C++ kernels are not designed for the DTensor API. + if isinstance(inp, DTensor): + inp = inp.to_local() + # Activation recomputation is used and this is the second forward phase. if self.fp8 and in_fp8_activation_recompute_phase(): delayed_scaling_recipe = self.fp8_meta["recipe"].delayed() @@ -1221,7 +1227,7 @@ def grad_output_preprocess( def register_parameter(self, name, param, **kwargs): """ Thin wrapper around PyTorch parameter registration to stash additional parameter - metedata used in deferred initialization. + metadata used in deferred initialization. """ super().register_parameter(name, param) # Initialize param_init_meta exactly once during the init. 
FSDP2 can call @@ -1277,14 +1283,15 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: raise RuntimeError("Weight quantizer has not been initialized") quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) quantizer.internal = False - if is_dtensor and isinstance(quantizer, Float8CurrentScalingQuantizer): + if ( + is_dtensor + and isinstance(quantizer, Float8CurrentScalingQuantizer) + and quantizer.amax_reduction_group is None + ): + # If the amax_reduction_group is not set by `set_device_mesh`, + # then default to the DTensor's full DeviceMesh. device_mesh = dtensor_param.device_mesh - amax_reduction_group = ( - device_mesh.get_group(mesh_dim="shard") - if device_mesh.ndim > 1 - else device_mesh.get_group() - ) - quantizer.amax_reduction_group = amax_reduction_group + quantizer.amax_reduction_group = device_mesh.get_group() quantizer.with_amax_reduction = True # Quantize parameter param = quantizer(param) @@ -1294,15 +1301,15 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: # re-applying the nn.Parameter() wrap is a no-op when the input is already # a parameter so we always re-apply it just for extra safety. if is_dtensor: - # recreate the DTensor from the parameter. - dtensor_param = DTensor.from_local( + # Recreate the DTensor from the Parameter, inheriting + # all attributes originally set on the Parameter. 
+ dtensor_param = _convert_param_to_dtensor_param( param, device_mesh=dtensor_param.device_mesh, placements=dtensor_param.placements, shape=dtensor_param.size(), stride=dtensor_param.stride(), ) - dtensor_param = torch.nn.Parameter(dtensor_param) else: param = torch.nn.Parameter(param) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 2f859e748b..2eb9983aed 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -10,6 +10,8 @@ import functools import torch +from torch.distributed import DeviceMesh +from torch.distributed.tensor import DTensor import transformer_engine_torch as tex @@ -37,6 +39,7 @@ get_distributed_world_size, is_fp8_activation_recompute_enabled, in_fp8_activation_recompute_phase, + _convert_param_to_dtensor_param, ) from ..cpp_extensions import ( general_grouped_gemm, @@ -598,7 +601,8 @@ class GroupedLinear(TransformerEngineBaseModule): Notes ----- GroupedLinear doesn't really handle the TP communications inside. The ``tp_size`` and - ``parallel_mode`` are used to determine the shapes of weights and biases. + ``parallel_mode`` are used to determine the shapes of weights and biases, while ``tp_mesh`` + and ``weight_mesh`` support DTensor compatibility with FSDP2 and DCP. The TP communication should be handled in the dispatch and combine stages of MoE models. """ @@ -625,6 +629,8 @@ def __init__( delay_wgrad_compute: bool = False, save_original_input: bool = False, name: Optional[str] = None, + tp_mesh: Optional[DeviceMesh] = None, + weight_mesh: Optional[DeviceMesh] = None, ) -> None: super().__init__(name) @@ -726,6 +732,9 @@ def __init__( self.init_fp8_metadata(num_gemms=self.num_gemms) is_meta = torch.device(device).type == "meta" + if tp_mesh is not None or weight_mesh is not None: + # Apply DeviceMesh and DTensor-related modifications. 
+ self.set_device_mesh(tp_mesh=tp_mesh, weight_mesh=weight_mesh) self.reset_parameters(defer_init=is_meta) if self.wgrad_store.delay_wgrad_compute(): @@ -758,7 +767,7 @@ def make_grouped_weights(self, defer_init=False) -> None: else None ) if recipe is not None and (recipe.delayed() or recipe.float8_current_scaling()): - self.set_tensor_parallel_attributes(defer_init=defer_init) + self._set_tensor_parallel_attributes(defer_init=defer_init) return weights = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] @@ -782,15 +791,27 @@ def make_grouped_weights(self, defer_init=False) -> None: # Re-register the grouped weights as parameters. for i in range(self.num_gemms): + # Prepare the new Parameter. + new_param = torch.nn.Parameter(grouped_weights.quantized_tensors[i]) + if isinstance(getattr(self, f"weight{i}"), DTensor): + # Convert to DTensor with properties equivalent to the original DTensor. + orig_dtensor_param = getattr(self, f"weight{i}") + new_param = _convert_param_to_dtensor_param( + new_param, + device_mesh=orig_dtensor_param.device_mesh, + placements=orig_dtensor_param.placements, + shape=orig_dtensor_param.size(), + stride=orig_dtensor_param.stride(), + ) self.register_parameter( f"weight{i}", - torch.nn.Parameter(grouped_weights.quantized_tensors[i]), + new_param, init_fn=self.init_method, get_rng_state_tracker=self.get_rng_state_tracker, fp8_meta_index=self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"], ) - self.set_tensor_parallel_attributes(defer_init=defer_init) + self._set_tensor_parallel_attributes(defer_init=defer_init) def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) @@ -798,7 +819,95 @@ def reset_parameters(self, defer_init=False): if bool(int(os.getenv("NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS", "0"))): self.make_grouped_weights(defer_init=defer_init) - def set_tensor_parallel_attributes(self, defer_init=False) -> None: + def set_device_mesh( + self, + tp_mesh: 
Optional[DeviceMesh] = None, + weight_mesh: Optional[DeviceMesh] = None, + ) -> None: + """ + Set DeviceMesh(s) used for sharding weights and convert main weights into DTensor + depending on the TransformerEngine class to support FSDP-TP sharding with FSDP2. + + TransformerEngine manages tensor parallel mechanics, while DTensor offers seamless + integration with Torch DCP checkpointing. This method should only be invoked when + using DTensor parameters, e.g. when using FSDP2 or DCP. + + When FSDP2 fully_shard() encounters any DTensor Shard(s), it will automatically + convert them into FSDP-TP strided or non-strided shards depending on the current + sharding dimension and factor of the DTensor. When the sharding dimension of FSDP + matches that of TP, FSDP uses a _StridedShard placement type instead of Shard. + This experimental FSDP-TP logic presides in this FSDP2 initialization function: + ``torch.distributed.fsdp._fully_shard._fsdp_param._init_sharded_param`` + + Parameters + ---------- + tp_mesh : Optional[DeviceMesh] + A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"]. + Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP. + weight_mesh : Optional[DeviceMesh] + A 1-D DeviceMesh containing a weight-sharding mesh dimension. Only required + when using the FP8 Current (per-tensor) Scaling recipe on sharded DTensor + parameters and if the DTensor DeviceMesh includes dimensions that do not + shard weights, such as in the case of HSDP (DP-Replicate x DP-Shard). + For example: + - device_mesh["dp"] for FSDP. + - device_mesh["dp_cp"] if using CP ranks in FSDP. + - device_mesh["tp"] if using TP. + - device_mesh["dp_cp_tp"] if strided-sharding with FSDP-TP. + """ + if tp_mesh is not None: + # Validate TP DeviceMesh / Group. Must be consistent with tp_size. 
+ assert ( + tp_mesh.ndim == 1 and self.tp_size == tp_mesh.size() + ), ( + f"TransformerEngine {self.__class__.__name__} TP init size ({self.tp_size}) " + f"does not match the size of the " + f"provided TP DeviceMesh ({tp_mesh.size()})." + ) + # Set the tensor parallel group from the mesh. + self.set_tensor_parallel_group(tp_mesh.get_group()) + + # Construct TP-sharded DTensors. + from torch.distributed.tensor.placement_types import Replicate, Shard + + for weight in self.weight_names: + param = getattr(self, weight) + placements = (Replicate(),) + if self.parallel_mode == "column": + placements = (Shard(dim=0),) + elif self.parallel_mode == "row": + placements = (Shard(dim=1),) + setattr( + self, + weight, + _convert_param_to_dtensor_param(param, tp_mesh, placements=placements), + ) + for bias in self.bias_names: + param = getattr(self, bias) + placements = (Replicate(),) + if self.parallel_mode == "column": + placements = (Shard(dim=0),) + setattr( + self, + bias, + _convert_param_to_dtensor_param(param, tp_mesh, placements=placements), + ) + + # Set amax_reduction_group to the FSDP and/or TP sharding mesh + # for per-tensor scaling recipes. Parameters must be registered. + if weight_mesh is not None and self.quantizers["scaling_fwd"]: + for weight in self.weight_names: + # Get fp8_meta_index and associated quantizer. + fp8_meta_index = self.param_init_meta[weight].fp8_meta_index + quantizer = self.quantizers["scaling_fwd"][fp8_meta_index] + if isinstance(quantizer, Float8CurrentScalingQuantizer): + # If not set, will default to DTensor.device_mesh.get_group()! + # MUST be provided when using HSDP (DP-Replicate) or when the + # DeviceMesh includes dimensions that do not shard weights!
+ quantizer.amax_reduction_group = weight_mesh.get_group() + quantizer.with_amax_reduction = True + + def _set_tensor_parallel_attributes(self, defer_init=False) -> None: """Set attributes needed for TP""" if not defer_init: @@ -865,7 +974,7 @@ def forward( inp = self.prepare_forward(inp, num_gemms=self.num_gemms) try: weight_tensors = self._get_weight_tensors() - bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] + bias_tensors = self._get_bias_tensors() quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() @@ -985,7 +1094,12 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" - weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] + weight_tensors = [] + for i in range(self.num_gemms): + weight = getattr(self, f"weight{i}") + if isinstance(weight, DTensor): + weight = weight.to_local() + weight_tensors.append(weight) if not self.fp8 and any(isinstance(w, QuantizedTensorStorage) for w in weight_tensors): warnings.warn( "You are using quantized weights without quantized compute. 
" @@ -997,6 +1111,16 @@ def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage ] return weight_tensors + def _get_bias_tensors(self) -> List[torch.Tensor]: + """Get the bias tensors of the module.""" + bias_tensors = [] + for i in range(self.num_gemms): + bias = getattr(self, f"bias{i}") + if isinstance(bias, DTensor): + bias = bias.to_local() + bias_tensors.append(bias) + return bias_tensors + def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" if not self.fp8 and not self.fp8_calibration and not self.primary_weights_in_fp8: diff --git a/transformer_engine/pytorch/module/layernorm.py b/transformer_engine/pytorch/module/layernorm.py index d4f0a78ba2..d45e3c717c 100644 --- a/transformer_engine/pytorch/module/layernorm.py +++ b/transformer_engine/pytorch/module/layernorm.py @@ -7,6 +7,7 @@ from typing import Iterable, Optional, Union import torch +from torch.distributed import DeviceMesh from transformer_engine.pytorch.ops import LayerNorm as _LayerNormOp @@ -134,6 +135,48 @@ def reset_parameters(self, defer_init: Optional[bool] = None) -> None: self.weight.sequence_parallel = self.sequence_parallel self.bias.sequence_parallel = self.sequence_parallel + def set_device_mesh( + self, + tp_mesh: Optional[DeviceMesh] = None, + weight_mesh: Optional[DeviceMesh] = None, + ) -> None: + """ + Set DeviceMesh(s) used for sharding weights and convert main weights into DTensor + depending on the TransformerEngine class to support FSDP-TP sharding with FSDP2. + + TransformerEngine manages tensor parallel mechanics, while DTensor offers seamless + integration with Torch DCP checkpointing. This method should only be invoked when + using DTensor parameters, e.g. when using FSDP2 or DCP. + + When FSDP2 fully_shard() encounters any DTensor Shard(s), it will automatically + convert them into FSDP-TP strided or non-strided shards depending on the current + sharding dimension and factor of the DTensor. 
When the sharding dimension of FSDP + matches that of TP, FSDP uses a _StridedShard placement type instead of Shard. + This experimental FSDP-TP logic presides in this FSDP2 initialization function: + ``torch.distributed.fsdp._fully_shard._fsdp_param._init_sharded_param`` + + Parameters + ---------- + tp_mesh : Optional[DeviceMesh] + A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"]. + Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP. + weight_mesh : Optional[DeviceMesh] + Quantized DTensor parameters are currently not supported for FusibleOperation(s), + and this mesh is not used. + """ + if tp_mesh is not None: + # Construct TP-Replicate DTensors. Used to shim non-TP parameters for compatibility + # with DTensor parameters in TP layers to support DTensor operations. + from transformer_engine.pytorch.distributed import _convert_param_to_dtensor_param + from torch.distributed.tensor.placement_types import Replicate + + self.weight = _convert_param_to_dtensor_param( + self.weight, tp_mesh, placements=(Replicate(),) + ) + self.bias = _convert_param_to_dtensor_param( + self.bias, tp_mesh, placements=(Replicate(),) + ) + @property def fwd_ln_sm_margin(self) -> int: """Shim for backward compatibility""" diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 702916696b..651ef5bee3 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -11,6 +11,8 @@ import torch from torch.nn import init +from torch.distributed import DeviceMesh +from torch.distributed.tensor import DTensor import transformer_engine_torch as tex @@ -51,6 +53,7 @@ in_fp8_activation_recompute_phase, _fsdp_scatter_tensors, _fsdp_gather_tensors, + _convert_param_to_dtensor_param, ) from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo @@ -64,6 +67,7 @@ restore_from_saved, ) 
from ...debug.pytorch.debug_state import TEDebugState +from ..tensor.float8_tensor import Float8CurrentScalingQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..cpu_offload import ( is_cpu_offload_enabled, @@ -1098,6 +1102,19 @@ class LayerNormLinear(TransformerEngineBaseModule): used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described `here `_. When set to ``None``, no communication is performed. + tp_mesh : Optional[DeviceMesh], default = None + A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"]. + Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP. + weight_mesh : Optional[DeviceMesh], default = None + A 1-D DeviceMesh containing a weight-sharding mesh dimension. Only required + when using the FP8 Current (per-tensor) Scaling recipe on sharded DTensor + parameters and if the DTensor DeviceMesh includes dimensions that do not + shard weights, such as in the case of HSDP (DP-Replicate x DP-Shard). + For example: + - device_mesh["dp"] for FSDP. + - device_mesh["dp_cp"] if using CP ranks in FSDP. + - device_mesh["tp"] if using TP. + - device_mesh["dp_cp_tp"] if strided-sharding with FSDP-TP. Optimization parameters ----------------------- @@ -1159,6 +1176,8 @@ def __init__( delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, name: Optional[str] = None, + tp_mesh: Optional[DeviceMesh] = None, + weight_mesh: Optional[DeviceMesh] = None, ) -> None: super().__init__(name) @@ -1379,6 +1398,9 @@ def __init__( if with_fp8_params: self.init_fp8_metadata() + if tp_mesh is not None or weight_mesh is not None: + # Apply DeviceMesh and DTensor-related modifications. 
+ self.set_device_mesh(tp_mesh=tp_mesh, weight_mesh=weight_mesh) self.reset_parameters(defer_init=device == "meta") # For RPL, bias has to be added after TP collectives @@ -1429,29 +1451,107 @@ def reset_layer_norm_parameters(self) -> None: def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) + self._set_tensor_parallel_attributes(defer_init=defer_init) - if not defer_init: - # Set parallelism attributes for layer norm parameters - setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel) - if self.normalization != "RMSNorm": - setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) + def set_device_mesh( + self, + tp_mesh: Optional[DeviceMesh] = None, + weight_mesh: Optional[DeviceMesh] = None, + ) -> None: + """ + Set DeviceMesh(s) used for sharding weights and convert main weights into DTensor + depending on the TransformerEngine class to support FSDP-TP sharding with FSDP2. - # Set parallelism attributes for linear weights + TransformerEngine manages tensor parallel mechanics, while DTensor offers seamless + integration with Torch DCP checkpointing. This method should only be invoked when + using DTensor parameters, e.g. when using FSDP2 or DCP. + + When FSDP2 fully_shard() encounters any DTensor Shard(s), it will automatically + convert them into FSDP-TP strided or non-strided shards depending on the current + sharding dimension and factor of the DTensor. When the sharding dimension of FSDP + matches that of TP, FSDP uses a _StridedShard placement type instead of Shard. + This experimental FSDP-TP logic presides in this FSDP2 initialization function: + ``torch.distributed.fsdp._fully_shard._fsdp_param._init_sharded_param`` + + Parameters + ---------- + tp_mesh : Optional[DeviceMesh] + A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"]. + Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP. 
+        weight_mesh : Optional[DeviceMesh]
+            A 1-D DeviceMesh containing a weight-sharding mesh dimension. Only required
+            when using the FP8 Current (per-tensor) Scaling recipe on sharded DTensor
+            parameters and if the DTensor DeviceMesh includes dimensions that do not
+            shard weights, such as in the case of HSDP (DP-Replicate x DP-Shard).
+            For example:
+            - device_mesh["dp"] for FSDP.
+            - device_mesh["dp_cp"] if using CP ranks in FSDP.
+            - device_mesh["tp"] if using TP.
+            - device_mesh["dp_cp_tp"] if strided-sharding with FSDP-TP.
+        """
+        if tp_mesh is not None:
+            # Validate TP DeviceMesh / Group. Must be consistent with tp_size.
+            assert (
+                tp_mesh.ndim == 1
+                and self.tp_size == tp_mesh.size()
+            ), (
+                f"TransformerEngine {self.__class__.__name__} TP init size ({self.tp_size}) "
+                f"does not match the size of the provided TP DeviceMesh ({tp_mesh.size()})."
+            )
+            # Set the tensor parallel group from the mesh.
+            self.set_tensor_parallel_group(tp_mesh.get_group())
+
+            # Construct TP-sharded DTensors.
+ from torch.distributed.tensor.placement_types import Replicate, Shard + + # Linear for weight in self.weight_names: - set_tensor_model_parallel_attributes( - tensor=getattr(self, weight), - is_parallel=True, - dim=1 if self.parallel_mode == "row" else 0, - stride=1, + param = getattr(self, weight) + placements = (Replicate(),) + if self.parallel_mode == "column": + placements = (Shard(dim=0),) + elif self.parallel_mode == "row": + placements = (Shard(dim=1),) + setattr( + self, + weight, + _convert_param_to_dtensor_param(param, tp_mesh, placements=placements), + ) + for bias in self.bias_names: + param = getattr(self, bias) + placements = (Replicate(),) + if self.parallel_mode == "column": + placements = (Shard(dim=0),) + setattr( + self, + bias, + _convert_param_to_dtensor_param(param, tp_mesh, placements=placements), + ) + # LayerNorm + placements = (Replicate(),) + if self.parallel_mode == "row": + placements = (Shard(dim=0),) + self.layer_norm_weight = _convert_param_to_dtensor_param( + self.layer_norm_weight, tp_mesh, placements=placements + ) + if self.layer_norm_bias is not None: + self.layer_norm_bias = _convert_param_to_dtensor_param( + self.layer_norm_bias, tp_mesh, placements=placements ) - # Set parallelism attributes for linear biases - if self.use_bias: - for bias in self.bias_names: - if self.parallel_mode == "row": - setattr(getattr(self, bias), "sequence_parallel", self.sequence_parallel) - elif self.parallel_mode == "column": - set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1) + # Set amax_reduction_group to the FSDP and/or TP sharding mesh + # for per-tensor scaling recipes. Parameters must be registered. + if weight_mesh is not None and self.quantizers["scaling_fwd"]: + for weight in self.weight_names: + # Get fp8_meta_index and associated quantizer. 
+ fp8_meta_index = self.param_init_meta[weight].fp8_meta_index + quantizer = self.quantizers["scaling_fwd"][fp8_meta_index] + if isinstance(quantizer, Float8CurrentScalingQuantizer): + # If not set, will default to DTensor.device_mesh.get_group()! + # MUST be provided when using HSDP (DP-Replicate) or when the + # DeviceMesh includes dimensions that do not shard weights! + quantizer.amax_reduction_group = weight_mesh.get_group() + quantizer.with_amax_reduction = True @no_torch_dynamo() def forward( @@ -1512,8 +1612,10 @@ def forward( ) try: - # Get concatenated weight and bias tensors + # Get concatenated weight and bias tensors. weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + # Get the layernorm weight and bias tensors. + ln_weight, ln_bias = self._get_layernorm_weight_and_bias() quantizers = ( self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled) @@ -1583,8 +1685,8 @@ def forward( out = fwd_fn( *autograd_ctx, inp, - self.layer_norm_weight, - self.layer_norm_bias, + ln_weight, + ln_bias, weight_tensor, bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None, non_tensor_args, @@ -1649,15 +1751,16 @@ def _get_debug_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): for name, q in zip(names, original_quantizers) ) - def _get_weight_and_bias_tensors(self): - # Get concatenated weight and bias tensors + def _get_weight_and_bias_tensors(self) -> Tuple[torch.Tensor, torch.Tensor]: + # Get concatenated weight and bias tensors. unfused_weights = self._get_weight_tensors() - weight_tensor = noop_cat(unfused_weights) if self.use_bias: - bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) + unfused_biases = self._get_bias_tensors() + bias_tensor = noop_cat(unfused_biases) else: - bias_tensor = getattr(self, self.bias_names[0]) # Unused + # Unused. 
+ bias_tensor = getattr(self, self.bias_names[0]) return weight_tensor, bias_tensor def onnx_forward( @@ -1684,10 +1787,11 @@ def onnx_forward( inp_dtype = inp.dtype weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() + ln_weight, ln_bias = self._get_layernorm_weight_and_bias() ln_out, ln_out_return = onnx_layernorm( inp, - self.layer_norm_weight, - self.layer_norm_bias, + ln_weight, + ln_bias, self.eps, self.normalization, self.zero_centered_gamma, @@ -1787,7 +1891,12 @@ def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" - unfused_weights = [getattr(self, name) for name in self.weight_names] + unfused_weights = [] + for name in self.weight_names: + weight = getattr(self, name) + if isinstance(weight, DTensor): + weight = weight.to_local() + unfused_weights.append(weight) if any(isinstance(w, QuantizedTensor) for w in unfused_weights): if self.fp8: if len(unfused_weights) != 1: @@ -1802,6 +1911,26 @@ def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage unfused_weights = [w.dequantize() for w in unfused_weights] return unfused_weights + def _get_bias_tensors(self) -> List[torch.Tensor]: + """Get the bias tensors of the module.""" + unfused_biases = [] + for name in self.bias_names: + bias = getattr(self, name) + if isinstance(bias, DTensor): + bias = bias.to_local() + unfused_biases.append(bias) + return unfused_biases + + def _get_layernorm_weight_and_bias(self) -> List[Optional[torch.Tensor]]: + """Get the weight and bias of the layer norm.""" + ln_weight = self.layer_norm_weight + if isinstance(ln_weight, DTensor): + ln_weight = ln_weight.to_local() + ln_bias = self.layer_norm_bias + if isinstance(ln_bias, DTensor): + ln_bias = ln_bias.to_local() + return [ln_weight, ln_bias] + def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the 
module.""" if not self.fp8 and not self.fp8_calibration: @@ -1809,3 +1938,28 @@ def _get_weight_quantizers(self) -> List[Quantizer]: weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] weight_quantizer.internal = True return [weight_quantizer] + + def _set_tensor_parallel_attributes(self, defer_init=False) -> None: + """Set tensor and sequence parallelism attributes.""" + if not defer_init: + # Set parallelism attributes for layer norm parameters + setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel) + if self.normalization != "RMSNorm": + setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) + + # Set parallelism attributes for linear weights + for weight in self.weight_names: + set_tensor_model_parallel_attributes( + tensor=getattr(self, weight), + is_parallel=True, + dim=1 if self.parallel_mode == "row" else 0, + stride=1, + ) + + # Set parallelism attributes for linear biases + if self.use_bias: + for bias in self.bias_names: + if self.parallel_mode == "row": + setattr(getattr(self, bias), "sequence_parallel", self.sequence_parallel) + elif self.parallel_mode == "column": + set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 4532ea60e7..51e523ed09 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -12,6 +12,8 @@ import torch from torch.nn.parameter import Parameter from torch.nn import init +from torch.distributed import DeviceMesh +from torch.distributed.tensor import DTensor import transformer_engine_torch as tex @@ -58,6 +60,7 @@ _fsdp_scatter_tensors, _get_cuda_rng_state, _set_cuda_rng_state, + _convert_param_to_dtensor_param, ) from ..constants import dist_group_type from ..jit import no_torch_dynamo @@ -1368,7 +1371,7 @@ def fc2_wgrad_gemm( # Make sure required data is 
available if ctx.fc1_weight_quantizer is not None and isinstance( - ctx.fc1_weight_quantizer, QuantizedTensorStorage + ctx.fc1_weight, QuantizedTensorStorage ): ctx.fc1_weight.update_usage(columnwise_usage=True) @@ -1721,6 +1724,19 @@ class LayerNormMLP(TransformerEngineBaseModule): ``set_tensor_parallel_group(tp_group)`` method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives. + tp_mesh : Optional[DeviceMesh], default = None + A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"]. + Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP. + weight_mesh : Optional[DeviceMesh], default = None + A 1-D DeviceMesh containing a weight-sharding mesh dimension. Only required + when using the FP8 Current (per-tensor) Scaling recipe on sharded DTensor + parameters and if the DTensor DeviceMesh includes dimensions that do not + shard weights, such as in the case of HSDP (DP-Replicate x DP-Shard). + For example: + - device_mesh["dp"] for FSDP. + - device_mesh["dp_cp"] if using CP ranks in FSDP. + - device_mesh["tp"] if using TP. + - device_mesh["dp_cp_tp"] if strided-sharding with FSDP-TP. Optimization parameters ----------------------- @@ -1798,6 +1814,8 @@ def __init__( delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, checkpoint: bool = False, + tp_mesh: Optional[DeviceMesh] = None, + weight_mesh: Optional[DeviceMesh] = None, ) -> None: super().__init__(name) @@ -1940,6 +1958,9 @@ def __init__( if with_fp8_params: self.init_fp8_metadata(num_gemms=2) + if tp_mesh is not None or weight_mesh is not None: + # Apply DeviceMesh and DTensor-related modifications. 
+ self.set_device_mesh(tp_mesh=tp_mesh, weight_mesh=weight_mesh) self.reset_parameters(defer_init=device == "meta") # For RPL, bias has to be added after TP collectives @@ -1996,20 +2017,95 @@ def reset_layer_norm_parameters(self) -> None: def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) + self._set_tensor_parallel_attributes(defer_init=defer_init) - if not defer_init: - # Set parallel attributes for layer norm parameters - setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel) - if self.normalization != "RMSNorm": - setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) + def set_device_mesh( + self, + tp_mesh: Optional[DeviceMesh] = None, + weight_mesh: Optional[DeviceMesh] = None, + ) -> None: + """ + Set DeviceMesh(s) used for sharding weights and convert main weights into DTensor + depending on the TransformerEngine class to support FSDP-TP sharding with FSDP2. - # Set parallel attributes for linear parameters - set_tensor_model_parallel_attributes(self.fc1_weight, True, 0, 1) - set_tensor_model_parallel_attributes(self.fc2_weight, True, 1, 1) - if self.use_bias: - set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1) - if self.set_parallel_mode: - setattr(self.fc2_bias, "sequence_parallel", self.sequence_parallel) + TransformerEngine manages tensor parallel mechanics, while DTensor offers seamless + integration with Torch DCP checkpointing. This method should only be invoked when + using DTensor parameters, e.g. when using FSDP2 or DCP. + + When FSDP2 fully_shard() encounters any DTensor Shard(s), it will automatically + convert them into FSDP-TP strided or non-strided shards depending on the current + sharding dimension and factor of the DTensor. When the sharding dimension of FSDP + matches that of TP, FSDP uses a _StridedShard placement type instead of Shard. 
+        This experimental FSDP-TP logic presides in this FSDP2 initialization function:
+        ``torch.distributed.fsdp._fully_shard._fsdp_param._init_sharded_param``
+
+        Parameters
+        ----------
+        tp_mesh : Optional[DeviceMesh]
+            A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"].
+            Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP.
+        weight_mesh : Optional[DeviceMesh]
+            A 1-D DeviceMesh containing a weight-sharding mesh dimension. Only required
+            when using the FP8 Current (per-tensor) Scaling recipe on sharded DTensor
+            parameters and if the DTensor DeviceMesh includes dimensions that do not
+            shard weights, such as in the case of HSDP (DP-Replicate x DP-Shard).
+            For example:
+            - device_mesh["dp"] for FSDP.
+            - device_mesh["dp_cp"] if using CP ranks in FSDP.
+            - device_mesh["tp"] if using TP.
+            - device_mesh["dp_cp_tp"] if strided-sharding with FSDP-TP.
+        """
+        if tp_mesh is not None:
+            # Validate TP DeviceMesh / Group. Must be consistent with tp_size.
+            assert (
+                tp_mesh.ndim == 1
+                and self.tp_size == tp_mesh.size()
+            ), (
+                f"TransformerEngine {self.__class__.__name__} TP init size ({self.tp_size}) "
+                f"does not match the size of the provided TP DeviceMesh ({tp_mesh.size()})."
+            )
+            # Set the tensor parallel group from the mesh.
+            self.set_tensor_parallel_group(tp_mesh.get_group())
+
+            # Construct TP-sharded DTensors.
+ from torch.distributed.tensor.placement_types import Replicate, Shard + + # FC1 -> Column-Parallel -> Shard(dim=0) + self.fc1_weight = _convert_param_to_dtensor_param( + self.fc1_weight, tp_mesh, placements=(Shard(dim=0),) + ) + self.fc1_bias = _convert_param_to_dtensor_param( + self.fc1_bias, tp_mesh, placements=(Shard(dim=0),) + ) + # FC2 Weight -> Row-Parallel -> Shard(dim=1) + self.fc2_weight = _convert_param_to_dtensor_param( + self.fc2_weight, tp_mesh, placements=(Shard(dim=1),) + ) + # LN & FC2 Bias -> Replicate() + self.fc2_bias = _convert_param_to_dtensor_param( + self.fc2_bias, tp_mesh, placements=(Replicate(),) + ) + self.layer_norm_weight = _convert_param_to_dtensor_param( + self.layer_norm_weight, tp_mesh, placements=(Replicate(),) + ) + if self.layer_norm_bias is not None: + self.layer_norm_bias = _convert_param_to_dtensor_param( + self.layer_norm_bias, tp_mesh, placements=(Replicate(),) + ) + + # Set amax_reduction_group to the FSDP and/or TP sharding mesh + # for per-tensor scaling recipes. Parameters must be registered. + if weight_mesh is not None and self.quantizers["scaling_fwd"]: + for weight in ["fc1_weight", "fc2_weight"]: + # Get fp8_meta_index and associated quantizer. + fp8_meta_index = self.param_init_meta[weight].fp8_meta_index + quantizer = self.quantizers["scaling_fwd"][fp8_meta_index] + if isinstance(quantizer, Float8CurrentScalingQuantizer): + # If not set, will default to DTensor.device_mesh.get_group()! + # MUST be provided when using HSDP (DP-Replicate) or when the + # DeviceMesh includes dimensions that do not shard weights! 
+ quantizer.amax_reduction_group = weight_mesh.get_group() + quantizer.with_amax_reduction = True @no_torch_dynamo() def forward( @@ -2088,8 +2184,9 @@ def forward( # Get weight tensors fc1_weight, fc2_weight = self._get_weight_tensors() - fc1_bias = self.fc1_bias if self.use_bias else None - fc2_bias = self.fc2_bias if self.use_bias else None + fc1_bias, fc2_bias = self._get_bias_tensors() + ln_weight, ln_bias = self._get_layernorm_weight_and_bias() + if not self.fp8: if isinstance(fc1_weight, Float8Tensor): fc1_weight = fc1_weight.dequantize() @@ -2159,8 +2256,8 @@ def forward( out = fwd_fn( *autograd_ctx, inp, - self.layer_norm_weight, - self.layer_norm_bias, + ln_weight, + ln_bias, fc1_weight, fc1_bias, fc2_weight, @@ -2278,14 +2375,14 @@ def onnx_forward( inp_dtype = inp.dtype fc1_weight, fc2_weight = self._get_weight_tensors() - fc1_bias = self.fc1_bias if self.use_bias else None - fc2_bias = self.fc2_bias if self.use_bias else None + fc1_bias, fc2_bias = self._get_bias_tensors() + ln_weight, ln_bias = self._get_layernorm_weight_and_bias() # layernorm + fp8 cast ln_out, ln_out_return = onnx_layernorm( inp, - self.layer_norm_weight, - self.layer_norm_bias, + ln_weight, + ln_bias, self.eps, self.normalization, self.zero_centered_gamma, @@ -2471,7 +2568,33 @@ def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" - return [self.fc1_weight, self.fc2_weight] + fc1_weight = self.fc1_weight + if isinstance(fc1_weight, DTensor): + fc1_weight = fc1_weight.to_local() + fc2_weight = self.fc2_weight + if isinstance(fc2_weight, DTensor): + fc2_weight = fc2_weight.to_local() + return [fc1_weight, fc2_weight] + + def _get_bias_tensors(self) -> List[torch.Tensor]: + """Get the bias tensors of the module.""" + fc1_bias = self.fc1_bias if self.use_bias else None + if isinstance(fc1_bias, DTensor): + fc1_bias = fc1_bias.to_local() + 
fc2_bias = self.fc2_bias if self.use_bias else None + if isinstance(fc2_bias, DTensor): + fc2_bias = fc2_bias.to_local() + return [fc1_bias, fc2_bias] + + def _get_layernorm_weight_and_bias(self) -> List[Optional[torch.Tensor]]: + """Get the weight and bias of the layer norm.""" + ln_weight = self.layer_norm_weight + if isinstance(ln_weight, DTensor): + ln_weight = ln_weight.to_local() + ln_bias = self.layer_norm_bias + if isinstance(ln_bias, DTensor): + ln_bias = ln_bias.to_local() + return [ln_weight, ln_bias] def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" @@ -2519,3 +2642,19 @@ def backward_dw(self): del fc1_wgrad del fc1_bias_grad self._trigger_wgrad_accumulation_and_reduce_hooks() + + def _set_tensor_parallel_attributes(self, defer_init=False) -> None: + """Set tensor and sequence parallelism attributes.""" + if not defer_init: + # Set parallel attributes for layer norm parameters + setattr(self.layer_norm_weight, "sequence_parallel", self.sequence_parallel) + if self.normalization != "RMSNorm": + setattr(self.layer_norm_bias, "sequence_parallel", self.sequence_parallel) + + # Set parallel attributes for linear parameters + set_tensor_model_parallel_attributes(self.fc1_weight, True, 0, 1) + set_tensor_model_parallel_attributes(self.fc2_weight, True, 1, 1) + if self.use_bias: + set_tensor_model_parallel_attributes(self.fc1_bias, True, 0, 1) + if self.set_parallel_mode: + setattr(self.fc2_bias, "sequence_parallel", self.sequence_parallel) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 23ad8cacb0..1157d1db59 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -9,6 +9,8 @@ import warnings import torch +from torch.distributed import DeviceMesh +from torch.distributed.tensor import DTensor import transformer_engine_torch as tex @@ -50,6 +52,7 @@ in_fp8_activation_recompute_phase, 
_fsdp_scatter_tensors, _fsdp_gather_tensors, + _convert_param_to_dtensor_param, ) from ..cpp_extensions import ( general_gemm, @@ -1034,6 +1037,19 @@ class Linear(TransformerEngineBaseModule): used to decide whether this Linear layer is Column Parallel Linear or Row Parallel Linear as described `here `_. When set to ``None``, no communication is performed. + tp_mesh : Optional[DeviceMesh], default = None + A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"]. + Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP. + weight_mesh : Optional[DeviceMesh], default = None + A 1-D DeviceMesh containing a weight-sharding mesh dimension. Only required + when using the FP8 Current (per-tensor) Scaling recipe on sharded DTensor + parameters and if the DTensor DeviceMesh includes dimensions that do not + shard weights, such as in the case of HSDP (DP-Replicate x DP-Shard). + For example: + - device_mesh["dp"] for FSDP. + - device_mesh["dp_cp"] if using CP ranks in FSDP. + - device_mesh["tp"] if using TP. + - device_mesh["dp_cp_tp"] if strided-sharding with FSDP-TP. Optimization parameters ----------------------- @@ -1097,6 +1113,8 @@ def __init__( symmetric_ar_type: Optional[str] = None, save_original_input: bool = False, name: Optional[str] = None, + tp_mesh: Optional[DeviceMesh] = None, + weight_mesh: Optional[DeviceMesh] = None, ) -> None: super().__init__(name) @@ -1294,6 +1312,9 @@ def __init__( if with_fp8_params: self.init_fp8_metadata() + if tp_mesh is not None or weight_mesh is not None: + # Apply DeviceMesh and DTensor-related modifications. 
+ self.set_device_mesh(tp_mesh=tp_mesh, weight_mesh=weight_mesh) self.reset_parameters(defer_init=device == "meta") # For RPL, bias has to be added after TP collectives @@ -1321,24 +1342,95 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) + self._set_tensor_parallel_attributes(defer_init=defer_init) + + def set_device_mesh( + self, + tp_mesh: Optional[DeviceMesh] = None, + weight_mesh: Optional[DeviceMesh] = None, + ) -> None: + """ + Set DeviceMesh(s) used for sharding weights and convert main weights into DTensor + depending on the TransformerEngine class to support FSDP-TP sharding with FSDP2. + + TransformerEngine manages tensor parallel mechanics, while DTensor offers seamless + integration with Torch DCP checkpointing. This method should only be invoked when + using DTensor parameters, e.g. when using FSDP2 or DCP. + + When FSDP2 fully_shard() encounters any DTensor Shard(s), it will automatically + convert them into FSDP-TP strided or non-strided shards depending on the current + sharding dimension and factor of the DTensor. When the sharding dimension of FSDP + matches that of TP, FSDP uses a _StridedShard placement type instead of Shard. + This experimental FSDP-TP logic presides in this FSDP2 initialization function: + ``torch.distributed.fsdp._fully_shard._fsdp_param._init_sharded_param`` + + Parameters + ---------- + tp_mesh : Optional[DeviceMesh] + A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"]. + Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP. + weight_mesh : Optional[DeviceMesh] + A 1-D DeviceMesh containing a weight-sharding mesh dimension. Only required + when using the FP8 Current (per-tensor) Scaling recipe on sharded DTensor + parameters and if the DTensor DeviceMesh includes dimensions that do not + shard weights, such as in the case of HSDP (DP-Replicate x DP-Shard). 
+            For example:
+            - device_mesh["dp"] for FSDP.
+            - device_mesh["dp_cp"] if using CP ranks in FSDP.
+            - device_mesh["tp"] if using TP.
+            - device_mesh["dp_cp_tp"] if strided-sharding with FSDP-TP.
+        """
+        if tp_mesh is not None:
+            # Validate TP DeviceMesh / Group. Must be consistent with tp_size.
+            assert (
+                tp_mesh.ndim == 1
+                and self.tp_size == tp_mesh.size()
+            ), (
+                f"TransformerEngine {self.__class__.__name__} TP init size ({self.tp_size}) "
+                f"does not match the size of the provided TP DeviceMesh ({tp_mesh.size()})."
+            )
+            # Set the tensor parallel group from the mesh.
+            self.set_tensor_parallel_group(tp_mesh.get_group())
+
+            # Construct TP-sharded DTensors.
+            from torch.distributed.tensor.placement_types import Replicate, Shard
 
-        if not defer_init:
-            # Set parallelism attributes for linear weights
             for weight in self.weight_names:
-                set_tensor_model_parallel_attributes(
-                    tensor=getattr(self, weight),
-                    is_parallel=True,
-                    dim=1 if self.parallel_mode == "row" else 0,
-                    stride=1,
+                param = getattr(self, weight)
+                placements = (Replicate(),)
+                if self.parallel_mode == "column":
+                    placements = (Shard(dim=0),)
+                elif self.parallel_mode == "row":
+                    placements = (Shard(dim=1),)
+                setattr(
+                    self,
+                    weight,
+                    _convert_param_to_dtensor_param(param, tp_mesh, placements=placements),
+                )
+            for bias in self.bias_names:
+                param = getattr(self, bias)
+                placements = (Replicate(),)
+                if self.parallel_mode == "column":
+                    placements = (Shard(dim=0),)
+                setattr(
+                    self,
+                    bias,
+                    _convert_param_to_dtensor_param(param, tp_mesh, placements=placements),
                 )
-        # Set parallelism attributes for linear biases
-        if self.use_bias:
-            for bias in self.bias_names:
-                if self.parallel_mode == "row":
-                    setattr(getattr(self, bias), "sequence_parallel", self.sequence_parallel)
-                elif self.parallel_mode == "column":
-                    set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1)
+        # Set amax_reduction_group to the FSDP and/or TP sharding mesh
+        # for per-tensor scaling recipes.
Parameters must be registered. + if weight_mesh is not None and self.quantizers["scaling_fwd"]: + for weight in self.weight_names: + # Get fp8_meta_index and associated quantizer. + fp8_meta_index = self.param_init_meta[weight].fp8_meta_index + quantizer = self.quantizers["scaling_fwd"][fp8_meta_index] + if isinstance(quantizer, Float8CurrentScalingQuantizer): + # If not set, will default to DTensor.device_mesh.get_group()! + # MUST be provided when using HSDP (DP-Replicate) or when the + # DeviceMesh includes dimensions that do not shard weights! + quantizer.amax_reduction_group = weight_mesh.get_group() + quantizer.with_amax_reduction = True @no_torch_dynamo() def forward( @@ -1518,7 +1610,12 @@ def _get_debug_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" - unfused_weights = [getattr(self, name) for name in self.weight_names] + unfused_weights = [] + for name in self.weight_names: + weight = getattr(self, name) + if isinstance(weight, DTensor): + weight = weight.to_local() + unfused_weights.append(weight) if any(isinstance(w, QuantizedTensor) for w in unfused_weights): if self.fp8: if len(unfused_weights) != 1: @@ -1533,15 +1630,26 @@ def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage unfused_weights = [w.dequantize() for w in unfused_weights] return unfused_weights - def _get_weight_and_bias_tensors(self) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - # Get concatenated weight and bias tensors + def _get_bias_tensors(self) -> List[torch.Tensor]: + """Get the bias tensors of the module.""" + unfused_biases = [] + for name in self.bias_names: + bias = getattr(self, name) + if isinstance(bias, DTensor): + bias = bias.to_local() + unfused_biases.append(bias) + return unfused_biases + + def _get_weight_and_bias_tensors(self) -> List[Optional[torch.Tensor]]: + # Get concatenated weight and bias 
tensors. unfused_weights = self._get_weight_tensors() weight_tensor = noop_cat(unfused_weights) if self.use_bias: - bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) + unfused_biases = self._get_bias_tensors() + bias_tensor = noop_cat(unfused_biases) else: bias_tensor = None - return weight_tensor, bias_tensor + return [weight_tensor, bias_tensor] def onnx_forward( self, @@ -1668,3 +1776,23 @@ def _get_weight_quantizers(self) -> List[Quantizer]: weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] weight_quantizer.internal = True return [weight_quantizer] + + def _set_tensor_parallel_attributes(self, defer_init=False) -> None: + """Set tensor and sequence parallelism attributes.""" + if not defer_init: + # Set parallelism attributes for linear weights + for weight in self.weight_names: + set_tensor_model_parallel_attributes( + tensor=getattr(self, weight), + is_parallel=True, + dim=1 if self.parallel_mode == "row" else 0, + stride=1, + ) + + # Set parallelism attributes for linear biases + if self.use_bias: + for bias in self.bias_names: + if self.parallel_mode == "row": + setattr(getattr(self, bias), "sequence_parallel", self.sequence_parallel) + elif self.parallel_mode == "column": + set_tensor_model_parallel_attributes(getattr(self, bias), True, 0, 1) diff --git a/transformer_engine/pytorch/module/rmsnorm.py b/transformer_engine/pytorch/module/rmsnorm.py index ace4be31de..e02dbf6a29 100644 --- a/transformer_engine/pytorch/module/rmsnorm.py +++ b/transformer_engine/pytorch/module/rmsnorm.py @@ -7,6 +7,7 @@ from typing import Iterable, Optional, Union import torch +from torch.distributed import DeviceMesh from transformer_engine.pytorch.ops import RMSNorm as _RMSNormOp @@ -137,6 +138,45 @@ def reset_parameters(self, defer_init: Optional[bool] = None) -> None: if self.sequence_parallel is not None: self.weight.sequence_parallel = self.sequence_parallel + def set_device_mesh( + self, + tp_mesh: 
Optional[DeviceMesh] = None, +        weight_mesh: Optional[DeviceMesh] = None, +    ) -> None: +        """ +        Set DeviceMesh(s) used for sharding weights and convert main weights into DTensor +        depending on the TransformerEngine class to support FSDP-TP sharding with FSDP2. + +        TransformerEngine manages tensor parallel mechanics, while DTensor offers seamless +        integration with Torch DCP checkpointing. This method should only be invoked when +        using DTensor parameters, e.g. when using FSDP2 or DCP. + +        When FSDP2 fully_shard() encounters any DTensor Shard(s), it will automatically +        convert them into FSDP-TP strided or non-strided shards depending on the current +        sharding dimension and factor of the DTensor. When the sharding dimension of FSDP +        matches that of TP, FSDP uses a _StridedShard placement type instead of Shard. +        This experimental FSDP-TP logic resides in this FSDP2 initialization function: +        ``torch.distributed.fsdp._fully_shard._fsdp_param._init_sharded_param`` + +        Parameters +        ---------- +        tp_mesh : Optional[DeviceMesh] +            A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"]. +            Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP. +        weight_mesh : Optional[DeviceMesh] +            Quantized DTensor parameters are currently not supported for FusibleOperation(s), +            and this mesh is not used. +        """ +        if tp_mesh is not None: +            # Construct TP-Replicate DTensors. Used to shim non-TP parameters for compatibility +            # with DTensor parameters in TP layers to support DTensor operations.
+ from transformer_engine.pytorch.distributed import _convert_param_to_dtensor_param + from torch.distributed.tensor.placement_types import Replicate + + self.weight = _convert_param_to_dtensor_param( + self.weight, tp_mesh, placements=(Replicate(),) + ) + @property def fwd_rmsnorm_sm_margin(self) -> int: """Shim for backward compatibility""" diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index 3fda5145c6..d191d41bcc 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -185,6 +185,8 @@ def op_forward( # Check tensor dims weight = self.weight + if isinstance(weight, DTensor): + weight = weight.to_local() weight_dims = tuple(weight.size()) input_dims = tuple(input_.size()) if len(input_dims) < len(weight_dims) or input_dims[-len(weight_dims) :] != weight_dims: @@ -192,13 +194,16 @@ def op_forward( f"Input tensor (shape={input_dims}) " f"and weight tensor (shape={weight_dims}) are not compatible" ) + bias = self.bias + if isinstance(bias, DTensor): + bias = bias.to_local() # Check input tensors inner_dim = math.prod(weight_dims) dtype = maybe_autocast_dtype(default_dtype=weight.dtype) x = maybe_dequantize(input_.contiguous(), dtype).view((-1, inner_dim)) - w = maybe_dequantize(self.weight, dtype).view((inner_dim,)) - b = maybe_dequantize(self.bias, dtype).view((inner_dim,)) + w = maybe_dequantize(weight, dtype).view((inner_dim,)) + b = maybe_dequantize(bias, dtype).view((inner_dim,)) # Compute layer norm sm_margin = self._sm_margins["forward" if ctx.requires_grad else "inference"] @@ -235,13 +240,16 @@ def op_backward( x, means, rstdevs = ctx.saved_tensors # Tensor dims - weight_dims = self.weight.size() + weight = self.weight + if isinstance(weight, DTensor): + weight = weight.to_local() + weight_dims = weight.size() inner_dim = math.prod(weight_dims) # Check input tensors dtype = ctx.dtype dy = 
maybe_dequantize(grad_output.contiguous(), dtype).view(x.size()) - w = maybe_dequantize(self.weight, dtype).view((inner_dim,)) + w = maybe_dequantize(weight, dtype).view((inner_dim,)) # Compute layer norm backward pass dx, dw, db = layernorm_bwd( @@ -262,7 +270,25 @@ def op_backward( # Reshape results grad_input = dx.view(grad_output.size()) grad_weight = dw.view(weight_dims) + # If the main weight was a DTensor, convert the wgrad to a DTensor. + if isinstance(self.weight, DTensor): + grad_weight = DTensor.from_local( + grad_weight, + device_mesh=self.weight.device_mesh, + placements=self.weight.placements, + shape=self.weight.size(), + stride=self.weight.stride(), + ) grad_bias = db.view(weight_dims) + # If the main weight was a DTensor, convert the wgrad to a DTensor. + if isinstance(self.bias, DTensor): + grad_bias = DTensor.from_local( + grad_bias, + device_mesh=self.bias.device_mesh, + placements=self.bias.placements, + shape=self.bias.size(), + stride=self.bias.stride(), + ) return grad_input, (grad_weight, grad_bias) def op_onnx_forward( diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index 1d8d8be971..2a36a3d3a9 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -11,6 +11,7 @@ from typing import Optional import torch +from torch.distributed.tensor import DTensor from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd from ...constants import TE_DType @@ -168,6 +169,8 @@ def op_forward( # Check tensor dims weight = self.weight + if isinstance(weight, DTensor): + weight = weight.to_local() weight_dims = tuple(weight.size()) input_dims = tuple(input_.size()) if len(input_dims) < len(weight_dims) or input_dims[-len(weight_dims) :] != weight_dims: @@ -180,7 +183,7 @@ def op_forward( inner_dim = math.prod(weight_dims) dtype = maybe_autocast_dtype(default_dtype=weight.dtype) x = maybe_dequantize(input_.contiguous(), 
dtype).view((-1, inner_dim)) - w = maybe_dequantize(self.weight, dtype).view((inner_dim,)) + w = maybe_dequantize(weight, dtype).view((inner_dim,)) # Compute RMSNorm sm_margin = self._sm_margins["forward" if ctx.requires_grad else "inference"] @@ -216,13 +219,16 @@ def op_backward( x, rstdevs = ctx.saved_tensors # Tensor dims - weight_dims = self.weight.size() + weight = self.weight + if isinstance(weight, DTensor): + weight = weight.to_local() + weight_dims = weight.size() inner_dim = math.prod(weight_dims) # Check input tensors dtype = ctx.dtype dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size()) - w = maybe_dequantize(self.weight, dtype).view((inner_dim,)) + w = maybe_dequantize(weight, dtype).view((inner_dim,)) # Compute RMSNorm backward pass dx, dw = rmsnorm_bwd( @@ -241,6 +247,15 @@ def op_backward( # Reshape results grad_input = dx.view(grad_output.size()) grad_weight = dw.view(weight_dims) + # If the main weight was a DTensor, convert the wgrad to a DTensor. + if isinstance(self.weight, DTensor): + grad_weight = DTensor.from_local( + grad_weight, + device_mesh=self.weight.device_mesh, + placements=self.weight.placements, + shape=self.weight.size(), + stride=self.weight.stride(), + ) return grad_input, (grad_weight,) def op_onnx_forward( diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 55bca49af3..54879f585f 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -8,6 +8,7 @@ import warnings import torch from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState +from torch.distributed.tensor import DTensor import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType @@ -895,6 +896,16 @@ def fsdp_post_all_gather( (data,) = all_gather_outputs (fp8_scale_inv, rowwise_usage, columnwise_usage, fp8_dtype) = metadata orig_shape = data.size() + + # If out is a 
DTensor from a previously un-sharded + # AG buffer, convert to local Tensor. + # FIXME(@cspades): FP8 parameters currently are not + # compatible with DCP checkpointing. + if isinstance(out, DTensor): + # out.to_local() is not supported with Torch Dispatch, + # for quantized tensors with _transpose usage. + out = out._local_tensor + # Quantizer has only columnwise usage set for backward pass # In Blackwell+ architectures, transpose is not needed at all, # even if columnwise usage is set. and is going to be handled diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 41d6c87f2b..00275bf0c2 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -11,6 +11,7 @@ import torch from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState +from torch.distributed.tensor import DTensor import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType @@ -728,6 +729,15 @@ def fsdp_post_all_gather( columnwise_scale_inv, (0, 0, 0, pad_dim0) ) + # If out is a DTensor from a previously un-sharded + # AG buffer, convert to local Tensor. + # FIXME(@cspades): FP8 parameters currently are not + # compatible with DCP checkpointing. + if isinstance(out, DTensor): + # out.to_local() is not supported with Torch Dispatch, + # for quantized tensors with _transpose usage. 
+ out = out._local_tensor + if out is not None: out._rowwise_data = rowwise_data out._rowwise_scale_inv = rowwise_scale_inv diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index cf7ce5e1a4..33586f34ab 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -9,6 +9,7 @@ from typing import Callable, List, Optional, Tuple, Union import torch +from torch.distributed import DeviceMesh from transformer_engine.pytorch.torch_version import torch_version from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm @@ -248,6 +249,19 @@ class TransformerLayer(torch.nn.Module): :meth:`set_tensor_parallel_group` method on the initialized module before the forward pass to supply the tensor parallel group needed for tensor and sequence parallel collectives. + tp_mesh : Optional[DeviceMesh], default = None + A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"]. + Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP. + weight_mesh : Optional[DeviceMesh], default = None + A 1-D DeviceMesh containing a weight-sharding mesh dimension. Only required + when using the FP8 Current (per-tensor) Scaling recipe on sharded DTensor + parameters and if the DTensor DeviceMesh includes dimensions that do not + shard weights, such as in the case of HSDP (DP-Replicate x DP-Shard). + For example: + - device_mesh["dp"] for FSDP. + - device_mesh["dp_cp"] if using CP ranks in FSDP. + - device_mesh["tp"] if using TP. + - device_mesh["dp_cp_tp"] if strided-sharding with FSDP-TP. 
Optimization parameters ----------------------- @@ -350,6 +364,8 @@ def __init__( qk_norm_eps: float = 1e-6, qk_norm_before_rope: bool = False, softmax_type: str = "vanilla", + tp_mesh: Optional[DeviceMesh] = None, + weight_mesh: Optional[DeviceMesh] = None, ) -> None: super().__init__() @@ -461,6 +477,8 @@ def __init__( qk_norm_eps=qk_norm_eps, qk_norm_before_rope=qk_norm_before_rope, name=self.name + ".self_attention" if self.name is not None else None, + tp_mesh=tp_mesh, + weight_mesh=weight_mesh, ) if layer_type == "decoder": @@ -478,6 +496,8 @@ def __init__( qk_norm_eps=qk_norm_eps, qk_norm_before_rope=qk_norm_before_rope, name=self.name + ".inter_attention" if self.name is not None else None, + tp_mesh=tp_mesh, + weight_mesh=weight_mesh, ) # LayerNorm -> activation(Linear + Bias) -> Linear @@ -514,6 +534,8 @@ def __init__( normalization=normalization, device=device, name=self.name + ".layernorm_mlp" if self.name is not None else None, + tp_mesh=tp_mesh, + weight_mesh=weight_mesh, ) self.hidden_dropout = hidden_dropout @@ -544,6 +566,9 @@ def __init__( zero_centered_gamma=zero_centered_gamma, device=device, ) + if tp_mesh is not None or weight_mesh is not None: + if hasattr(self.layernorm, "set_device_mesh"): + self.layernorm.set_device_mesh(tp_mesh=tp_mesh, weight_mesh=weight_mesh) def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: """ @@ -562,6 +587,62 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N if hasattr(child, "set_tensor_parallel_group"): child.set_tensor_parallel_group(tp_group) + def set_device_mesh( + self, + tp_mesh: Optional[DeviceMesh] = None, + weight_mesh: Optional[DeviceMesh] = None, + ) -> None: + """ + Set DeviceMesh(s) used for sharding weights and convert main weights into DTensor + depending on the TransformerEngine class to support FSDP-TP sharding with FSDP2. 
+ +        TransformerEngine manages tensor parallel mechanics, while DTensor offers seamless +        integration with Torch DCP checkpointing. This method should only be invoked when +        using DTensor parameters, e.g. when using FSDP2 or DCP. + +        When FSDP2 fully_shard() encounters any DTensor Shard(s), it will automatically +        convert them into FSDP-TP strided or non-strided shards depending on the current +        sharding dimension and factor of the DTensor. When the sharding dimension of FSDP +        matches that of TP, FSDP uses a _StridedShard placement type instead of Shard. +        This experimental FSDP-TP logic resides in this FSDP2 initialization function: +        ``torch.distributed.fsdp._fully_shard._fsdp_param._init_sharded_param`` + +        Parameters +        ---------- +        tp_mesh : Optional[DeviceMesh] +            A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"]. +            Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP. +        weight_mesh : Optional[DeviceMesh] +            A 1-D DeviceMesh containing a weight-sharding mesh dimension. Only required +            when using the FP8 Current (per-tensor) Scaling recipe on sharded DTensor +            parameters and if the DTensor DeviceMesh includes dimensions that do not +            shard weights, such as in the case of HSDP (DP-Replicate x DP-Shard). +            For example: +            - device_mesh["dp"] for FSDP. +            - device_mesh["dp_cp"] if using CP ranks in FSDP. +            - device_mesh["tp"] if using TP. +            - device_mesh["dp_cp_tp"] if strided-sharding with FSDP-TP. +        """ +        if tp_mesh is not None: +            # Validate TP DeviceMesh / Group. Must be consistent with tp_size. +            assert ( +                tp_mesh.ndim == 1 and self.tp_size == tp_mesh.size() +            ), ( +                f"TransformerEngine {self.__class__.__name__} TP init size ({self.tp_size}) " +                f"does not match the size of the provided TP DeviceMesh " +                f"({tp_mesh.size()})." +            ) +            # Set the tensor parallel group from the mesh.
+ self.set_tensor_parallel_group(tp_mesh.get_group()) + + if tp_mesh is not None or weight_mesh is not None: + # Iterate through child sub-modules without deep recursion. + # Automatically detects TransformerEngine TP modules and + # the capability to call this method at any level. + for name, child in self.named_children(): + if hasattr(child, "set_device_mesh"): + child.set_device_mesh(tp_mesh, weight_mesh) + def reset_fp8_meta_tensors(self) -> None: """Set TP group""" # Deep iterate but skip self to avoid infinite recursion.