Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

import torch
from torch.nn.parameter import Parameter
from torch.distributed import DeviceMesh
from torch.distributed.tensor import DTensor

import transformer_engine_torch as tex
from transformer_engine.common.recipe import (
Expand Down Expand Up @@ -44,6 +46,7 @@
set_all_rng_states,
CudaRNGStatesTracker,
graph_safe_rng_available,
_convert_param_to_dtensor_param,
)
from transformer_engine.pytorch.jit import no_torch_dynamo
from transformer_engine.pytorch.graph import is_graph_capturing
Expand Down Expand Up @@ -298,6 +301,11 @@ class DotProductAttention(TransformerEngineBaseModule):
``ProcessGroup`` is for :attr:`cp_comm_type` of ``"p2p"``, ``"all_gather"``, and ``"a2a"``.
``List[ProcessGroup]`` is for :attr:`cp_comm_type` of ``"a2a+p2p"``, where :attr:`cp_group[0]`
and :attr:`cp_group[1]` are for ``"a2a"`` and ``"p2p"`` communications respectively.
tp_mesh : Optional[DeviceMesh], default = None
A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"].
Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP.
weight_mesh : Optional[DeviceMesh], default = None
Not used for DotProductAttention as there are no quantized weights.
cp_global_ranks : list of global rank IDs, default = None
global rank IDs of GPUs that are in ``cp_group``.
cp_stream : CUDA stream, default = None
Expand Down Expand Up @@ -343,6 +351,8 @@ def __init__(
softmax_scale: Optional[float] = None,
softmax_type: str = "vanilla",
return_max_logit: Optional[bool] = False,
tp_mesh: Optional[DeviceMesh] = None,
weight_mesh: Optional[DeviceMesh] = None,
) -> None:
super().__init__()

Expand Down Expand Up @@ -477,6 +487,10 @@ def __init__(
return_max_logit=self.return_max_logit,
)

if tp_mesh is not None or weight_mesh is not None:
# Apply DeviceMesh and DTensor-related modifications.
self.set_device_mesh(tp_mesh=tp_mesh, weight_mesh=weight_mesh)

def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument
"""
Temporarily remove core_attention._extra_state as a missing key
Expand Down Expand Up @@ -533,6 +547,54 @@ def custom_forward(*input_args, **input_kwargs):

return hidden_states

def set_device_mesh(
self,
tp_mesh: Optional[DeviceMesh] = None,
weight_mesh: Optional[DeviceMesh] = None,
) -> None:
"""
Set DeviceMesh(s) used for sharding weights and convert main weights into DTensor
depending on the TransformerEngine class to support FSDP-TP sharding with FSDP2.

TransformerEngine manages tensor parallel mechanics, while DTensor offers seamless
integration with Torch DCP checkpointing. This method should only be invoked when
using DTensor parameters, e.g. when using FSDP2 or DCP.

When FSDP2 fully_shard() encounters any DTensor Shard(s), it will automatically
convert them into FSDP-TP strided or non-strided shards depending on the current
sharding dimension and factor of the DTensor. When the sharding dimension of FSDP
matches that of TP, FSDP uses a _StridedShard placement type instead of Shard.
This experimental FSDP-TP logic presides in this FSDP2 initialization function:
``torch.distributed.fsdp._fully_shard._fsdp_param._init_sharded_param``

Parameters
----------
tp_mesh : Optional[DeviceMesh]
A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"].
Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP.
weight_mesh : Optional[DeviceMesh]
Not used for DotProductAttention as there are no quantized weights.
"""
if tp_mesh is not None:
# Validate TP DeviceMesh / Group. Must be consistent with tp_size.
assert (
tp_mesh.ndim == 1 and self.tp_size == tp_mesh.size(),
(
f"TransformerEngine {self.__class__.__name__} TP init size ({self.tp_size}) "
f"does not match the size of the provided TP DeviceMesh ({tp_mesh.size()})."
),
)
# Set the tensor parallel group from the mesh.
self.set_tensor_parallel_group(tp_mesh.get_group())

# Construct TP-sharded DTensors.
if self.softmax_type == "learnable":
from torch.distributed.tensor.placement_types import Shard

self.softmax_offset = _convert_param_to_dtensor_param(
self.softmax_offset, tp_mesh, placements=(Shard(dim=0),)
)

def set_context_parallel_group(
self,
cp_group: Union[dist_group_type, List[dist_group_type], None],
Expand Down Expand Up @@ -802,6 +864,17 @@ def set_meta_tensor(self, fwd: bool, recipe: Union[Recipe, List[Recipe]]) -> Non
for recipe_state in recipe_states:
self.quantizers[fp8_meta_tensor_key].extend(recipe_state.make_quantizers())

def _get_softmax_offset(self) -> torch.Tensor:
"""Get the softmax offset."""
softmax_offset = (
self.softmax_offset.reshape(1, -1, 1, 1).to(torch.float32)
if self.softmax_offset is not None
else None
)
if isinstance(softmax_offset, DTensor):
softmax_offset = softmax_offset.to_local()
return softmax_offset

@no_torch_dynamo(recursive=False)
def forward(
self,
Expand Down Expand Up @@ -1434,11 +1507,7 @@ def forward(
)

# run attention
softmax_offset = (
self.softmax_offset.reshape(1, -1, 1, 1).to(torch.float32)
if self.softmax_offset is not None
else None
)
softmax_offset = self._get_softmax_offset()

if use_flash_attention:
if core_attention_bias_type == "alibi":
Expand Down
93 changes: 93 additions & 0 deletions transformer_engine/pytorch/attention/multi_head_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import collections
from typing import Callable, List, Optional, Tuple, Union
import torch
from torch.distributed import DeviceMesh

from transformer_engine.pytorch.quantization import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
Expand Down Expand Up @@ -191,6 +192,19 @@ class MultiheadAttention(torch.nn.Module):
``set_tensor_parallel_group(tp_group)`` method on the initialized module before the
forward pass to supply the tensor parallel group needed for tensor and sequence
parallel collectives.
tp_mesh : Optional[DeviceMesh], default = None
A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"].
Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP.
weight_mesh : Optional[DeviceMesh], default = None
A 1-D DeviceMesh containing a weight-sharding mesh dimension. Only required
when using the FP8 Current (per-tensor) Scaling recipe on sharded DTensor
parameters and if the DTensor DeviceMesh includes dimensions that do not
shard weights, such as in the case of HSDP (DP-Replicate x DP-Shard).
For example:
- device_mesh["dp"] for FSDP.
- device_mesh["dp_cp"] if using CP ranks in FSDP.
- device_mesh["tp"] if using TP.
- device_mesh["dp_cp_tp"] if strided-sharding with FSDP-TP.

Optimization parameters
-----------------------
Expand Down Expand Up @@ -286,6 +300,8 @@ def __init__(
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
softmax_type: str = "vanilla",
tp_mesh: Optional[DeviceMesh] = None,
weight_mesh: Optional[DeviceMesh] = None,
) -> None:
super().__init__()

Expand Down Expand Up @@ -357,6 +373,13 @@ def __init__(
self.q_norm, self.k_norm = self._create_qk_norm_modules(
qk_norm_type, qk_norm_eps, device, seq_length, micro_batch_size
)
if tp_mesh is not None or weight_mesh is not None:
# Apply DeviceMesh and DTensor-related modifications.
# Only necessary for trainable weighted norms.
if hasattr(self.q_norm, "set_device_mesh"):
self.q_norm.set_device_mesh(tp_mesh=tp_mesh, weight_mesh=weight_mesh)
if hasattr(self.k_norm, "set_device_mesh"):
self.k_norm.set_device_mesh(tp_mesh=tp_mesh, weight_mesh=weight_mesh)

qkv_parallel_mode = "column" if set_parallel_mode else None

Expand Down Expand Up @@ -389,6 +412,8 @@ def __init__(
normalization=normalization,
ub_name="qkv",
name=name + ".layernorm_linear_qkv" if name is not None else None,
tp_mesh=tp_mesh,
weight_mesh=weight_mesh,
**common_gemm_kwargs,
)
else:
Expand All @@ -401,6 +426,8 @@ def __init__(
parallel_mode=qkv_parallel_mode,
parameters_split=parameters_split,
name=name + ".linear_qkv" if name is not None else None,
tp_mesh=tp_mesh,
weight_mesh=weight_mesh,
**common_gemm_kwargs,
)
elif self.attention_type == "cross":
Expand All @@ -423,6 +450,8 @@ def __init__(
normalization=normalization,
ub_name="qkv",
name=name + ".layernorm_linear_q" if name is not None else None,
tp_mesh=tp_mesh,
weight_mesh=weight_mesh,
**common_gemm_kwargs,
)
else:
Expand All @@ -433,6 +462,8 @@ def __init__(
bias=bias,
return_bias=False,
parallel_mode=qkv_parallel_mode,
tp_mesh=tp_mesh,
weight_mesh=weight_mesh,
**common_gemm_kwargs,
)
self.key_value = Linear(
Expand All @@ -444,6 +475,8 @@ def __init__(
parallel_mode=qkv_parallel_mode,
parameters_split=("key", "value") if not fuse_qkv_params else None,
name=name + ".linear_kv" if name is not None else None,
tp_mesh=tp_mesh,
weight_mesh=weight_mesh,
**common_gemm_kwargs,
)

Expand All @@ -461,6 +494,8 @@ def __init__(
layer_number=self.layer_number,
attention_type=self.attention_type,
softmax_type=self.softmax_type,
tp_mesh=tp_mesh,
weight_mesh=weight_mesh,
)

# Linear
Expand All @@ -475,6 +510,8 @@ def __init__(
ub_overlap_ag=ub_overlap_ag,
ub_name="proj",
name=name + ".proj" if name is not None else None,
tp_mesh=tp_mesh,
weight_mesh=weight_mesh,
**common_gemm_kwargs,
)

Expand Down Expand Up @@ -562,6 +599,62 @@ def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> N
"""
self.tp_group = tp_group

def set_device_mesh(
self,
tp_mesh: Optional[DeviceMesh] = None,
weight_mesh: Optional[DeviceMesh] = None,
) -> None:
"""
Set DeviceMesh(s) used for sharding weights and convert main weights into DTensor
depending on the TransformerEngine class to support FSDP-TP sharding with FSDP2.

TransformerEngine manages tensor parallel mechanics, while DTensor offers seamless
integration with Torch DCP checkpointing. This method should only be invoked when
using DTensor parameters, e.g. when using FSDP2 or DCP.

When FSDP2 fully_shard() encounters any DTensor Shard(s), it will automatically
convert them into FSDP-TP strided or non-strided shards depending on the current
sharding dimension and factor of the DTensor. When the sharding dimension of FSDP
matches that of TP, FSDP uses a _StridedShard placement type instead of Shard.
This experimental FSDP-TP logic presides in this FSDP2 initialization function:
``torch.distributed.fsdp._fully_shard._fsdp_param._init_sharded_param``

Parameters
----------
tp_mesh : Optional[DeviceMesh]
A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"].
Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP.
weight_mesh : Optional[DeviceMesh]
A 1-D DeviceMesh containing a weight-sharding mesh dimension. Only required
when using the FP8 Current (per-tensor) Scaling recipe on sharded DTensor
parameters and if the DTensor DeviceMesh includes dimensions that do not
shard weights, such as in the case of HSDP (DP-Replicate x DP-Shard).
For example:
- device_mesh["dp"] for FSDP.
- device_mesh["dp_cp"] if using CP ranks in FSDP.
- device_mesh["tp"] if using TP.
- device_mesh["dp_cp_tp"] if strided-sharding with FSDP-TP.
"""
if tp_mesh is not None:
# Validate TP DeviceMesh / Group. Must be consistent with tp_size.
assert (
tp_mesh.ndim == 1 and self.tp_size == tp_mesh.size(),
(
f"TransformerEngine {self.__class__.__name__} TP init size ({self.tp_size}) "
f"does not match the size of the provided TP DeviceMesh ({tp_mesh.size()})."
),
)
# Set the tensor parallel group from the mesh.
self.set_tensor_parallel_group(tp_mesh.get_group())

if tp_mesh is not None or weight_mesh is not None:
# Iterate through child sub-modules without deep recursion.
# Automatically detects TransformerEngine TP modules and
# the capability to call this method at any level.
for name, child in self.named_children():
if hasattr(child, "set_device_mesh"):
child.set_device_mesh(tp_mesh, weight_mesh)

def set_context_parallel_group(
self,
cp_group: Union[dist_group_type, List[dist_group_type], None],
Expand Down
34 changes: 34 additions & 0 deletions transformer_engine/pytorch/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -1934,6 +1934,40 @@ def _get_module_fsdp_state(module):
return fsdp_state


def _convert_param_to_dtensor_param(
    param: torch.nn.Parameter,
    device_mesh: torch.distributed.DeviceMesh,
    placements: Tuple[torch.distributed.tensor.placement_types.Placement, ...],
    shape: Optional[torch.Size] = None,
    stride: Optional[Tuple[int, ...]] = None,
) -> torch.nn.Parameter:
    """Convert a parameter into a DTensor-backed parameter.

    Parameters
    ----------
    param : torch.nn.Parameter
        Parameter to convert. If it is already a DTensor parameter, its local
        tensor is extracted and re-wrapped with the new mesh/placements,
        overwriting the original distributed configuration.
    device_mesh : torch.distributed.DeviceMesh
        Mesh over which the DTensor is distributed.
    placements : tuple of Placement
        Per-mesh-dimension placements (e.g. ``(Shard(0),)``).
    shape, stride : optional
        Global shape/stride, forwarded to ``DTensor.from_local``; inferred
        when omitted.

    Returns
    -------
    torch.nn.Parameter
        New parameter wrapping the DTensor, inheriting ``requires_grad`` and
        any custom attributes of the original parameter.
    """
    from torch.distributed.tensor import DTensor

    # If the parameter is already a DTensor, extract the local Tensor.
    # We overwrite the original DTensor's distributed configuration.
    param_tensor = param
    if isinstance(param, DTensor):
        param_tensor = param.to_local()
    # Convert the parameter to a DTensor. Preserve requires_grad so frozen
    # parameters do not silently become trainable after conversion.
    new_param = torch.nn.Parameter(
        DTensor.from_local(
            param_tensor,
            device_mesh,
            placements=placements,
            shape=shape,
            stride=stride,
        ),
        requires_grad=param.requires_grad,
    )
    # Inherit attributes of the original Parameter.
    # For example, "param_init_meta" or "tensor_model_parallel".
    for key, val in param.__dict__.items():
        if not hasattr(new_param, key):
            # Set the original attribute.
            setattr(new_param, key, val)
    return new_param


def _fsdp_scatter_tensors(
fsdp_group: dist_group_type,
*tensors: torch.Tensor,
Expand Down
Loading