@@ -65,3 +65,116 @@ def stochastic_rounding(value, dtype, seed=0):
        return output

    return value.to(dtype=dtype)


# TODO: improve this?
def stochastic_float_to_fp4_e2m1(x, generator):
    # Split off the sign bit and work on magnitudes (note: dithers x in place).
    sign = torch.signbit(x).to(torch.uint8)
    x_abs = x.abs()

    # First pass: estimate the E2M1 exponent field (0..3), then add uniform
    # noise on the order of the local quantization step (2 ** (exp - 2));
    # this dither is what makes the rounding stochastic.
    exp = torch.floor(torch.log2(x_abs) + 1.0).clamp(0, 3)
    x += (torch.rand(x.size(), dtype=x.dtype, layout=x.layout, device=x.device, generator=generator) - 0.5) * (2 ** (exp - 2.0)) * 1.25

    # Second pass: recompute the exponent of the dithered value. The 1.1925
    # offset moves the exponent cutover to about 0.875 * 2^exp, the
    # round-to-nearest midpoint between the largest value at one exponent and
    # the smallest value at the next.
    x_abs = x.abs()
    exp = torch.floor(torch.log2(x_abs) + 1.1925).clamp(0, 3)

    # One mantissa bit: normals (exp > 0) store the fractional part of
    # x_abs / 2^(exp - 1); subnormals (exp == 0) store x_abs / 0.5.
    mantissa = torch.where(
        exp > 0,
        (x_abs / (2.0 ** (exp - 1)) - 1.0) * 2.0,
        (x_abs * 2.0)
    ).round().to(torch.uint8)

    # Assemble the 4-bit code: 1 sign bit, 2 exponent bits, 1 mantissa bit.
    fp4 = (sign << 3) | (exp.to(torch.uint8) << 1) | mantissa

    # Pack two FP4 codes per byte along the last dimension (halving it).
    fp4_flat = fp4.view(-1)
    packed = (fp4_flat[0::2] << 4) | fp4_flat[1::2]
    return packed.reshape(list(x.shape)[:-1] + [-1])
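

# A minimal decode sketch for sanity-checking the packing above; this helper is
# an illustration, not part of the change. The E2M1 code expands to the value
# table {0, 0.5, 1, 1.5, 2, 3, 4, 6} with a sign bit.
def _fp4_e2m1_unpack_example(packed, dtype=torch.float32):
    # Split each byte back into its high and low 4-bit codes.
    codes = torch.stack(((packed >> 4) & 0x0F, packed & 0x0F), dim=-1).flatten(-2)
    sign = 1.0 - 2.0 * ((codes >> 3) & 1).to(dtype)
    exp = ((codes >> 1) & 0x3).to(dtype)
    mantissa = (codes & 0x1).to(dtype)
    # exp == 0 encodes subnormals (0 or 0.5); otherwise (1 + m/2) * 2^(exp - 1).
    magnitude = torch.where(exp > 0, (1.0 + 0.5 * mantissa) * (2.0 ** (exp - 1.0)), 0.5 * mantissa)
    return sign * magnitude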


def to_blocked(input_matrix, flatten: bool = True) -> torch.Tensor:
    """
    Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.

    See:
        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

    Args:
        input_matrix: Input tensor of shape (H, W)

    Returns:
        The rearranged scales: a 1-D tensor of
        128 * ceil_div(H, 128) * 4 * ceil_div(W, 4) elements if flatten is
        True, otherwise the same data reshaped to
        (128 * ceil_div(H, 128), 4 * ceil_div(W, 4)).
    """

    def ceil_div(a, b):
        return (a + b - 1) // b

    rows, cols = input_matrix.shape
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)

    # Calculate the padded shape
    padded_rows = n_row_blocks * 128
    padded_cols = n_col_blocks * 4

    padded = input_matrix
    if (rows, cols) != (padded_rows, padded_cols):
        padded = torch.zeros(
            (padded_rows, padded_cols),
            device=input_matrix.device,
            dtype=input_matrix.dtype,
        )
        padded[:rows, :cols] = input_matrix

    # Rearrange the blocks: within each (128, 4) tile, the scale for row r,
    # column c lands at byte (r % 32) * 16 + (r // 32) * 4 + c.
    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)
    if flatten:
        return rearranged.flatten()

    return rearranged.reshape(padded_rows, padded_cols)
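

# Illustrative shape check for to_blocked (hypothetical sizes): a (100, 7)
# scale matrix spans ceil_div(100, 128) = 1 row block and ceil_div(7, 4) = 2
# column blocks, so it is zero-padded to (128, 8) before rearrangement:
#
#     scales = torch.randn(100, 7)
#     to_blocked(scales).shape                 # torch.Size([1024])
#     to_blocked(scales, flatten=False).shape  # torch.Size([128, 8])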


def stochastic_round_quantize_nvfp4(x, per_tensor_scale, pad_16x, seed=0):
    F4_E2M1_MAX = 6.0
    F8_E4M3_MAX = 448.0

    def roundup(x: int, multiple: int) -> int:
        """Round up x to the nearest multiple."""
        return ((x + multiple - 1) // multiple) * multiple

    orig_shape = x.shape

    # Handle padding
    if pad_16x:
        rows, cols = x.shape
        padded_rows = roundup(rows, 16)
        padded_cols = roundup(cols, 16)
        if padded_rows != rows or padded_cols != cols:
            x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
            # Note: We update orig_shape because the output tensor logic below assumes x.shape
            # matches what we want to produce. If we pad here, we want the padded output.
            orig_shape = x.shape

    block_size = 16

    # Per-block scaling: split each row into blocks of 16 values and give each
    # block an FP8 (e4m3) scale expressed relative to the per-tensor scale.
    x = x.reshape(orig_shape[0], -1, block_size)
    max_abs = torch.amax(torch.abs(x), dim=-1)
    block_scale = max_abs / F4_E2M1_MAX
    scaled_block_scales = block_scale / per_tensor_scale.to(block_scale.dtype)
    scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX).to(torch.float8_e4m3fn)
    total_scale = per_tensor_scale.to(x.dtype) * scaled_block_scales_fp8.to(x.dtype)

    # Handle zero blocks (from padding): avoid 0/0 NaN
    zero_scale_mask = (total_scale == 0)
    total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale)

    x = x / total_scale_safe.unsqueeze(-1)

    generator = torch.Generator(device=x.device)
    generator.manual_seed(seed)

    x = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(x), x)

    x = x.view(orig_shape)
    data_lp = stochastic_float_to_fp4_e2m1(x, generator=generator)

    blocked_scales = to_blocked(scaled_block_scales_fp8, flatten=False)
    return data_lp, blocked_scales
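

# Usage sketch (illustrative, not part of the change): deriving the per-tensor
# scale from the tensor's absolute maximum is an assumption here, chosen so the
# per-block scales fall inside the FP8 e4m3 range:
#
#     w = torch.randn(128, 64)
#     per_tensor_scale = w.abs().amax() / (6.0 * 448.0)  # amax / (F4_E2M1_MAX * F8_E4M3_MAX)
#     data_lp, blocked_scales = stochastic_round_quantize_nvfp4(w, per_tensor_scale, pad_16x=False, seed=0)
#     # data_lp: uint8 of shape (128, 32), two FP4 codes per byte
#     # blocked_scales: float8_e4m3fn scales of shape (128, 4) in the cuBLAS blocked layout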