diff --git a/examples/jax/encoder/run_test_multiprocessing_encoder.sh b/examples/jax/encoder/run_test_multiprocessing_encoder.sh index f2ef33da46..3c1f2ba1fb 100644 --- a/examples/jax/encoder/run_test_multiprocessing_encoder.sh +++ b/examples/jax/encoder/run_test_multiprocessing_encoder.sh @@ -11,10 +11,6 @@ TEST_CASES=( "test_te_current_scaling_fp8" "test_te_mxfp8" "test_te_nvfp4" -"test_te_bf16_shardy" -"test_te_delayed_scaling_fp8_shardy" -"test_te_current_scaling_fp8_shardy" -"test_te_nvfp4_shardy" ) : ${TE_PATH:=/opt/transformerengine} diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 73b93798a0..4400485f26 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -239,7 +239,6 @@ def check_fp8(state, var_collect, inputs, masks, labels): def train_and_evaluate(args): """Execute model training and evaluation loop.""" print(args) - jax.config.update("jax_use_shardy_partitioner", args.enable_shardy) train_ds, test_ds, num_embed = get_datasets(args.max_seq_len) @@ -474,9 +473,6 @@ def encoder_parser(args): parser.add_argument( "--enable-sp", action="store_true", default=False, help="Enable sequence parallelism." ) - parser.add_argument( - "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)." - ) return parser.parse_args(args) @@ -559,70 +555,6 @@ def test_te_nvfp4_with_sp(self): actual = train_and_evaluate(self.args) assert actual[0] < 0.40 and actual[1] > 0.82 - @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") - def test_te_bf16_shardy(self): - """Test Transformer Engine with BF16""" - self.args.enable_shardy = True - actual = train_and_evaluate(self.args) - assert actual[0] < 0.36 and actual[1] > 0.84 - - @unittest.skipIf(not is_fp8_supported, fp8_reason) - def test_te_delayed_scaling_fp8_shardy(self): - """Test Transformer Engine with DelayedScaling FP8""" - self.args.enable_shardy = True - self.args.use_fp8 = True - self.args.fp8_recipe = "DelayedScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.362 and actual[1] > 0.84 - - @unittest.skipIf(not is_fp8_supported, fp8_reason) - def test_te_delayed_scaling_fp8_with_sp_shardy(self): - """Test Transformer Engine with DelayedScaling FP8 + SP""" - self.args.enable_shardy = True - self.args.enable_sp = True - self.args.use_fp8 = True - self.args.fp8_recipe = "DelayedScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.362 and actual[1] > 0.84 - - @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) - def test_te_mxfp8_shardy(self): - """Test Transformer Engine with MXFP8""" - self.args.enable_shardy = True - self.args.use_fp8 = True - self.args.fp8_recipe = "MXFP8BlockScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.36 and actual[1] > 0.84 - - @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) - def test_te_nvfp4_shardy(self): - """Test Transformer Engine with NVFP4""" - self.args.enable_shardy = True - self.args.use_fp8 = True - self.args.fp8_recipe = "NVFP4BlockScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.40 and actual[1] > 0.82 - - @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) - def test_te_mxfp8_with_sp_shardy(self): - """Test Transformer Engine with MXFP8 + SP""" - self.args.enable_shardy = True - self.args.enable_sp = True - self.args.use_fp8 = True - self.args.fp8_recipe = "MXFP8BlockScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.36 and actual[1] > 0.84 - - @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) - def test_te_nvfp4_with_sp_shardy(self): - """Test Transformer Engine with NVFP4""" - self.args.enable_shardy = True - self.args.enable_sp = True - self.args.use_fp8 = True - self.args.fp8_recipe = "NVFP4BlockScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.40 and actual[1] > 0.82 - if __name__ == "__main__": train_and_evaluate(encoder_parser(None)) diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 22a89cc0a9..e2edc589b9 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -249,7 +249,6 @@ def replace_params(x): def train_and_evaluate(args): """Execute model training and evaluation loop.""" print(args) - jax.config.update("jax_use_shardy_partitioner", args.enable_shardy) train_ds, test_ds, num_embed = get_datasets(args.max_seq_len) num_gpu = jax.local_device_count() @@ -438,9 +437,6 @@ def encoder_parser(args): default="DelayedScaling", help="Use FP8 recipe (default: DelayedScaling)", ) - parser.add_argument( - "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)." - ) return parser.parse_args(args) @@ -494,49 +490,6 @@ def test_te_nvfp4(self): actual = train_and_evaluate(self.args) assert actual[0] < 0.52 and actual[1] > 0.74 - @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") - def test_te_bf16_shardy(self): - """Test Transformer Engine with BF16""" - self.args.enable_shardy = True - actual = train_and_evaluate(self.args) - assert actual[0] < 0.51 and actual[1] > 0.75 - - @unittest.skipIf(not is_fp8_supported, fp8_reason) - def test_te_delayed_scaling_fp8_shardy(self): - """Test Transformer Engine with DelayedScaling FP8""" - self.args.enable_shardy = True - self.args.use_fp8 = True - self.args.fp8_recipe = "DelayedScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.51 and actual[1] > 0.75 - - @unittest.skipIf(not is_fp8_supported, fp8_reason) - def test_te_current_scaling_fp8_shardy(self): - """Test Transformer Engine with CurrentScaling FP8""" - self.args.enable_shardy = True - self.args.use_fp8 = True - self.args.fp8_recipe = "Float8CurrentScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.51 and actual[1] > 0.749 - - @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) - def test_te_mxfp8_shardy(self): - """Test Transformer Engine with MXFP8""" - self.args.enable_shardy = True - self.args.use_fp8 = True - self.args.fp8_recipe = "MXFP8BlockScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.51 and actual[1] > 0.75 - - @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) - def test_te_nvfp4_shardy(self): - """Test Transformer Engine with NVFP4""" - self.args.enable_shardy = True - self.args.use_fp8 = True - self.args.fp8_recipe = "NVFP4BlockScaling" - actual = train_and_evaluate(self.args) - assert actual[0] < 0.52 and actual[1] > 0.74 - if __name__ == "__main__": train_and_evaluate(encoder_parser(None)) diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 0166b60acd..344e7d618b 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -359,7 +359,6 @@ def replace_params(x): def train_and_evaluate(args): """Execute model training and evaluation loop.""" print(args) - jax.config.update("jax_use_shardy_partitioner", args.enable_shardy) if args.process_id == 0: nltk.download("punkt_tab") @@ -605,9 +604,6 @@ def encoder_parser(args): default=0, help="the ID number of the current process (default: 0)", ) - parser.add_argument( - "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)." - ) return parser.parse_args(args) @@ -616,7 +612,7 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False): + def exec(self, use_fp8, fp8_recipe): """Run 5 epochs for testing""" args = encoder_parser(["--epochs", "5"]) @@ -632,7 +628,6 @@ def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False): args.num_process = num_gpu args.process_id = self.process_id args.fp8_recipe = fp8_recipe - args.enable_shardy = enable_shardy return train_and_evaluate(args) @@ -674,44 +669,6 @@ def test_te_nvfp4(self): result = self.exec(True, "NVFP4BlockScaling") assert result[0] < 0.451 and result[1] > 0.787 - @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") - def test_te_bf16_shardy(self): - """Test Transformer Engine with BF16""" - result = self.exec(False, None, enable_shardy=True) - assert result[0] < 0.43 and result[1] > 0.80 - - @unittest.skipIf( - not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8" - ) - def test_te_delayed_scaling_fp8_shardy(self): - """Test Transformer Engine with DelayedScaling FP8""" - result = self.exec(True, "DelayedScaling", enable_shardy=True) - assert result[0] < 0.43 and result[1] > 0.80 - - @unittest.skipIf( - not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8" - ) - def test_te_current_scaling_fp8_shardy(self): - """Test Transformer Engine with CurrentScaling FP8""" - result = self.exec(True, "Float8CurrentScaling", enable_shardy=True) - assert result[0] < 0.432 and result[1] > 0.80 - - @unittest.skipIf( - not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8" - ) - def test_te_mxfp8_shardy(self): - """Test Transformer Engine with MXFP8""" - result = self.exec(True, "MXFP8BlockScaling", enable_shardy=True) - assert result[0] < 0.43 and result[1] > 0.80 - - @unittest.skipIf( - not is_nvfp4_supported(), "Device compute capability 10.0+ is required for NVFP4" - ) - def test_te_nvfp4_shardy(self): - """Test Transformer Engine with NVFP4""" - result = self.exec(True, "NVFP4BlockScaling", enable_shardy=True) - assert result[0] < 0.451 and result[1] > 0.787 - if __name__ == "__main__": train_and_evaluate(encoder_parser(None)) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index d5ebe9f261..50c5de1db7 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -68,9 +68,7 @@ def impl_test_self_attn( attn_mask_type, dtype, softmax_type, - use_shardy, ): - jax.config.update("jax_use_shardy_partitioner", use_shardy) dropout_prob = 0.0 is_training = True batch, seqlen, num_head, hidden = data_shape @@ -178,48 +176,6 @@ def test_self_attn( attn_mask_type, dtype, softmax_type, - use_shardy=False, - ) - - @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) - @pytest.mark.parametrize( - "attn_bias_type, bias_shape", - [ - pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), - pytest.param(AttnBiasType.PRE_SCALE_BIAS, BiasShape._1HSS, id="PRE_SCALE_BIAS-1HSS"), - ], - ) - @pytest.mark.parametrize( - "softmax_type", - [ - pytest.param(AttnSoftmaxType.VANILLA_SOFTMAX, id="VANILLA_SOFTMAX"), - pytest.param(AttnSoftmaxType.OFF_BY_ONE_SOFTMAX, id="OFF_BY_ONE_SOFTMAX"), - pytest.param(AttnSoftmaxType.LEARNABLE_SOFTMAX, id="LEARNABLE_SOFTMAX"), - ], - ) - def test_self_attn_shardy( - self, - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - attn_bias_type, - bias_shape, - softmax_type, - ): - data_shape = (32, 512, 12, 64) - self.impl_test_self_attn( - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - data_shape, - attn_bias_type, - bias_shape, - AttnMaskType.PADDING_MASK, - jnp.bfloat16, - softmax_type, - use_shardy=True, ) @@ -348,7 +304,6 @@ def impl_test_context_parallel_attn( qkv_layout, load_balanced, cp_strategy, - use_shardy, use_scan_ring=False, window_size=None, stripe_size=None, @@ -366,8 +321,6 @@ def impl_test_context_parallel_attn( os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "1" else: os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] = "0" - - jax.config.update("jax_use_shardy_partitioner", use_shardy) attn_bias_type = AttnBiasType.NO_BIAS bias_shape = None dropout_prob = 0.0 @@ -452,45 +405,6 @@ def check_has_backend_for_mask(mask_type): runner.test_backward() del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] - @pytest_parametrize_wrapper( - "device_count,mesh_shape,mesh_axes,mesh_resource", - generate_context_parallel_configs_for_attn(), - ) - @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES) - @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) - @pytest.mark.parametrize( - "qkv_layout, attn_mask_type", - DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS, - ) - def test_context_parallel_allgather_attn_shardy( - self, - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - data_shape, - attn_mask_type, - dtype, - qkv_layout, - ): - if qkv_layout.is_thd(): - pytest.skip("Only BSHD layout is supported for CP + AG + Dual chunk attention") - kv_groups = 8 - self.impl_test_context_parallel_attn( - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - data_shape, - kv_groups, - attn_mask_type, - dtype, - qkv_layout, - load_balanced=True, - cp_strategy=CPStrategy.ALL_GATHER, - use_shardy=True, - ) - @pytest_parametrize_wrapper( "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs_for_attn(), @@ -551,7 +465,6 @@ def test_context_parallel_allgather_striped_attn( qkv_layout, load_balanced, CPStrategy.ALL_GATHER, - use_shardy=False, window_size=window_size, stripe_size=stripe_size, num_segments_per_seq=num_segments_per_seq, @@ -599,7 +512,6 @@ def test_context_parallel_allgather_attn( qkv_layout, load_balanced, CPStrategy.ALL_GATHER, - use_shardy=False, ) @pytest_parametrize_wrapper( @@ -664,53 +576,11 @@ def test_context_parallel_ring_attn( qkv_layout, load_balanced, CPStrategy.RING, - use_shardy=False, use_scan_ring=use_scan, window_size=window_size, stripe_size=stripe_size, ) - @pytest_parametrize_wrapper( - "device_count,mesh_shape,mesh_axes,mesh_resource", - generate_context_parallel_configs_for_attn(), - ) - @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1]) - @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) - @pytest.mark.parametrize( - "qkv_layout, attn_mask_type", - DISTRIBUTED_CONTEXT_SELF_ATTN_LAYOUTS_MASKS, - ) - def test_context_parallel_ring_attn_shardy( - self, - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - data_shape, - attn_mask_type, - dtype, - qkv_layout, - ): - kv_groups = 8 - # Set the stripe size to 1 (ring attention only support stripe_size=1) - stripe_size = 1 if qkv_layout.is_thd() else None - self.impl_test_context_parallel_attn( - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - data_shape, - kv_groups, - attn_mask_type, - dtype, - qkv_layout, - load_balanced=True, - cp_strategy=CPStrategy.RING, - use_shardy=False, - use_scan_ring=True, - stripe_size=stripe_size, - ) - REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES = { "L0": [[]], diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index e9a2fa49e2..bb1f38dcc8 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -87,7 +87,6 @@ def generate_collectives_count_ref( @pytest_parametrize_wrapper("zero_centered_gamma", [False, True]) @pytest_parametrize_wrapper("shard_weights", [False, True]) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) - @pytest_parametrize_wrapper("use_shardy", [False, True]) def test_layernorm( self, device_count, @@ -99,9 +98,7 @@ def test_layernorm( zero_centered_gamma, shard_weights, fp8_recipe, - use_shardy, ): - jax.config.update("jax_use_shardy_partitioner", use_shardy) epsilon = 1e-6 ln_type = "layernorm" q_dtype = jnp.float8_e4m3fn @@ -178,7 +175,6 @@ def ref_func(x, gamma, beta): @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("shard_weights", [False, True]) @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) - @pytest_parametrize_wrapper("use_shardy", [False, True]) def test_rmsnorm( self, device_count, @@ -189,9 +185,7 @@ def test_rmsnorm( dtype, shard_weights, fp8_recipe, - use_shardy, ): - jax.config.update("jax_use_shardy_partitioner", use_shardy) epsilon = 1e-6 ln_type = "rmsnorm" q_dtype = jnp.float8_e4m3fn diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index d214597cb3..abf579d48e 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -192,10 +192,8 @@ def _test_layernorm_mlp_grad( input_shape, dtype, quantization_recipe, - use_shardy, with_jax_gemm, ): - jax.config.update("jax_use_shardy_partitioner", use_shardy) device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config layernorm_type = "rmsnorm" @@ -313,36 +311,6 @@ def test_layernorm_mlp_grad( dtype, quantization_recipe, with_jax_gemm, - ): - if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4(): - pytest.skip("NVFP4 GEMM + Float16 output is unsupported!") - self._test_layernorm_mlp_grad( - mesh_config, - activation_type, - use_bias, - input_shape, - dtype, - quantization_recipe, - use_shardy=False, - with_jax_gemm=with_jax_gemm, - ) - - @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) - @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) - @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) - @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("quantization_recipe", [None] + SUPPORTED_RECIPES) - @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_layernorm_mlp_grad_shardy( - self, - mesh_config, - activation_type, - use_bias, - input_shape, - dtype, - quantization_recipe, - with_jax_gemm, ): if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4(): pytest.skip("NVFP4 GEMM + Float16 output is unsupported!") @@ -353,7 +321,6 @@ def test_layernorm_mlp_grad_shardy( input_shape, dtype, quantization_recipe=quantization_recipe, - use_shardy=True, with_jax_gemm=with_jax_gemm, ) @@ -366,10 +333,8 @@ def _test_layernorm_mlp( dtype, use_fp8, quantization_recipe, - use_shardy, with_jax_gemm, ): - jax.config.update("jax_use_shardy_partitioner", use_shardy) batch, seqlen, hidden_in = input_shape layernorm_type = "rmsnorm" @@ -481,7 +446,6 @@ def test_layernorm_mlp_layer( dtype, use_fp8=False, quantization_recipe=None, - use_shardy=False, with_jax_gemm=with_jax_gemm, ) @@ -512,58 +476,5 @@ def test_layernorm_mlp_layer_fp8( dtype, use_fp8=True, quantization_recipe=quantization_recipe, - use_shardy=False, - with_jax_gemm=with_jax_gemm, - ) - - @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) - @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) - @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) - @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_layernorm_mlp_layer_shardy( - self, mesh_config, activation_type, use_bias, input_shape, dtype, with_jax_gemm - ): - self._test_layernorm_mlp( - mesh_config, - activation_type, - use_bias, - input_shape, - dtype, - use_fp8=False, - quantization_recipe=None, - use_shardy=True, - with_jax_gemm=with_jax_gemm, - ) - - @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) - @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) - @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) - @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES) - @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_layernorm_mlp_layer_fp8_shardy( - self, - mesh_config, - activation_type, - use_bias, - input_shape, - dtype, - quantization_recipe, - with_jax_gemm, - ): - if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4(): - pytest.skip("NVFP4 GEMM + Float16 output is unsupported!") - self._test_layernorm_mlp( - mesh_config, - activation_type, - use_bias, - input_shape, - dtype, - use_fp8=True, - quantization_recipe=quantization_recipe, - use_shardy=True, with_jax_gemm=with_jax_gemm, ) diff --git a/tests/jax/test_distributed_permutation.py b/tests/jax/test_distributed_permutation.py index 5b6d8fec47..04ed236e81 100644 --- a/tests/jax/test_distributed_permutation.py +++ b/tests/jax/test_distributed_permutation.py @@ -135,7 +135,6 @@ def generate_routing_map( DISPATCH_COMBINE_CASES, ) @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("use_shardy", [False, True]) def test_local_token_dispatch( self, device_count, @@ -147,7 +146,6 @@ def test_local_token_dispatch( hidden_size, topk, dtype, - use_shardy, ): """ Test token_dispatch with sharded inputs. @@ -164,7 +162,6 @@ def test_local_token_dispatch( matching the sharded execution's output ordering. Tests both forward pass (output values) and backward pass (gradients). """ - jax.config.update("jax_use_shardy_partitioner", use_shardy) key = jax.random.PRNGKey(42) # Generate global inputs @@ -307,7 +304,6 @@ def ref_chunk_loss(inp_chunk, routing_chunk, probs_chunk): DISPATCH_COMBINE_CASES, ) @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("use_shardy", [False, True]) def test_local_roundtrip( self, device_count, @@ -319,7 +315,6 @@ def test_local_roundtrip( hidden_size, topk, dtype, - use_shardy, ): """ Test roundtrip: token_dispatch followed by token_combine with sharded inputs. @@ -332,7 +327,6 @@ def test_local_roundtrip( Tests both forward pass and backward pass (gradient should be 2*x). """ - jax.config.update("jax_use_shardy_partitioner", use_shardy) key = jax.random.PRNGKey(42) # Generate global inputs @@ -403,7 +397,6 @@ def roundtrip_loss(x, rm, mprobs): DISPATCH_COMBINE_PADDING_CASES, ) @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("use_shardy", [False, True]) def test_local_token_dispatch_with_padding( self, device_count, @@ -416,14 +409,12 @@ def test_local_token_dispatch_with_padding( topk, align_size, dtype, - use_shardy, ): """ Test token_dispatch with padding using sharded inputs. Tests both forward pass (output values) and backward pass (gradients). """ - jax.config.update("jax_use_shardy_partitioner", use_shardy) key = jax.random.PRNGKey(42) # Generate global inputs @@ -502,7 +493,6 @@ def loss_with_padding(x, rm, p): DISPATCH_COMBINE_PADDING_CASES, ) @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("use_shardy", [False, True]) def test_local_roundtrip_with_padding( self, device_count, @@ -515,7 +505,6 @@ def test_local_roundtrip_with_padding( topk, align_size, dtype, - use_shardy, ): """ Test roundtrip with padding/alignment using sharded inputs. @@ -523,7 +512,6 @@ def test_local_roundtrip_with_padding( With uniform merging probs, should recover original input. Tests both forward pass and backward pass. """ - jax.config.update("jax_use_shardy_partitioner", use_shardy) key = jax.random.PRNGKey(42) # Generate inputs diff --git a/tests/jax/test_distributed_softmax.py b/tests/jax/test_distributed_softmax.py index 0665baa4e3..ca1dcf1174 100644 --- a/tests/jax/test_distributed_softmax.py +++ b/tests/jax/test_distributed_softmax.py @@ -87,12 +87,9 @@ def impl_test_softmax( dtype, bad_sharding, broadcast_batch_mask, - use_shardy, ): if broadcast_batch_mask and softmax_fusion_type != SoftmaxFusionType.SCALED_MASKED: pytest.skip("Softmax type has no mask.") - - jax.config.update("jax_use_shardy_partitioner", use_shardy) target_func = partial( self.target_func, scale_factor=scale_factor, softmax_fusion_type=softmax_fusion_type ) @@ -181,35 +178,4 @@ def test_softmax( dtype, bad_sharding, broadcast_batch_mask, - use_shardy=True, - ) - - @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) - @pytest.mark.parametrize( - "softmax_fusion_type", [SoftmaxFusionType.SCALED, SoftmaxFusionType.SCALED_MASKED] - ) - @pytest.mark.parametrize("bad_sharding", [False, True]) - @pytest.mark.parametrize("broadcast_batch_mask", [False, True]) - def test_softmax_gspmd( - self, - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - softmax_fusion_type, - bad_sharding, - broadcast_batch_mask, - ): - self.impl_test_softmax( - device_count, - mesh_shape, - mesh_axes, - mesh_resource, - data_shape=[32, 12, 128, 128], - softmax_fusion_type=softmax_fusion_type, - scale_factor=1.0, - dtype=DTYPES[0], - bad_sharding=bad_sharding, - broadcast_batch_mask=broadcast_batch_mask, - use_shardy=False, ) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index b26e01c0c7..ae3888cf04 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -8,15 +8,23 @@ from abc import ABCMeta, abstractmethod from functools import partial +import jax from jax.extend import core from jax.interpreters import xla, mlir from jax.experimental.custom_partitioning import custom_partitioning from jax._src.interpreters import batching from jax._src import dispatch from jax import ffi +from packaging.version import Version as PkgVersion import transformer_engine_jax +# GSPMD sharding propagation (infer_sharding_from_operands) is removed in JAX > 0.9.1. +# Only register it for older JAX versions to maintain backwards compatibility. +# For JAX > 0.9.1, infer_sharding_from_operands is also removed from def_partition's signature, +# so it must not be passed at all. +_JAX_GSPMD_SUPPORTED = PkgVersion(jax.__version__) <= PkgVersion("0.9.1") + class BasePrimitive(metaclass=ABCMeta): """ @@ -143,13 +151,15 @@ def batcher(): """ return NotImplemented - @staticmethod - @abstractmethod - def infer_sharding_from_operands(): + @classmethod + def infer_sharding_from_operands(cls, *args, **kwargs): """ to describe infer_sharding_from_operands for custom_partitioning """ - return NotImplemented + raise NotImplementedError( + f"{cls.__name__} does not support GSPMD sharding propagation." + " Please use Shardy partitioner instead." + ) @staticmethod @abstractmethod @@ -172,6 +182,22 @@ def shardy_sharding_rule(*args): # Registry to store all registered primitive classes _primitive_registry = {} +_gspmd_deprecation_warned = False + + +def _warn_gspmd_deprecation_once(): + global _gspmd_deprecation_warned + if not _gspmd_deprecation_warned: + warnings.warn( + "GSPMD sharding propagation is planned to be removed in June 2026." + " It is no longer maintained or tested. Use it at your own risk." + " Please use Shardy partitioner instead." + " In case you cannot upgrade to a JAX version that supports Shardy, please reach out!", + DeprecationWarning, + stacklevel=3, + ) + _gspmd_deprecation_warned = True + def register_primitive(cls, outer_only=False): """ @@ -208,10 +234,16 @@ def name_of_wrapper_p(): outer_p.def_abstract_eval(cls.outer_abstract) batching.primitive_batchers[outer_p] = cls.batcher outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) + if _JAX_GSPMD_SUPPORTED: + if "infer_sharding_from_operands" in cls.__dict__: + _warn_gspmd_deprecation_once() + gspmd_kwargs = {"infer_sharding_from_operands": cls.infer_sharding_from_operands} + else: + gspmd_kwargs = {} outer_p_lower.def_partition( - infer_sharding_from_operands=cls.infer_sharding_from_operands, partition=cls.partition, sharding_rule=cls.shardy_sharding_rule, + **gspmd_kwargs, ) mlir.register_lowering( outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results)