Hongbinl/offload activation cuda graph mxfp8 offload fix #2716
lhb8125 wants to merge 25 commits into NVIDIA:main
Conversation
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>
for more information, see https://pre-commit.ci
…ub.com/lhb8125/TransformerEngine into hongbinl/offload_activation_cuda_graph
Greptile Summary

This PR enhances CPU offload for CUDA graphs with MXFP8 quantization by switching from individual tensor offload to buffer-based offload with subview restoration.

Major changes:
Critical issues found:
Confidence Score: 1/5

Important Files Changed

Last reviewed commit: 484b0d5
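The buffer-based offload with subview restoration that the summary describes can be sketched in isolation. This is a hedged illustration, not TransformerEngine's implementation: `pack_into_buffer` and `restore_views` are hypothetical helper names, and a real offload path would use pinned CPU memory and asynchronous copies rather than plain CPU tensors.

```python
import torch

def pack_into_buffer(tensors, align=16):
    """Copy tensors into one flat uint8 buffer.

    Records (byte_offset, shape, stride, dtype) per tensor so each one
    can later be restored as a typed subview of the shared buffer.
    """
    contig = [t.contiguous() for t in tensors]
    metadata, offset = [], 0
    for t in contig:
        offset = (offset + align - 1) // align * align  # keep dtype alignment
        metadata.append((offset, tuple(t.shape), t.stride(), t.dtype))
        offset += t.numel() * t.element_size()
    buffer = torch.zeros(offset, dtype=torch.uint8)
    for t, (byte_offset, shape, stride, dtype) in zip(contig, metadata):
        nbytes = t.numel() * t.element_size()
        buffer[byte_offset:byte_offset + nbytes] = t.reshape(-1).view(torch.uint8)
    return buffer, metadata

def restore_views(buffer, metadata):
    """Rebuild each tensor as a typed view into the shared byte buffer."""
    restored = []
    for byte_offset, shape, stride, dtype in metadata:
        nbytes = torch.empty(0, dtype=dtype).element_size()
        for s in shape:
            nbytes *= s
        # Reinterpret the byte slice with the original dtype, then restore
        # the original shape/stride as a view (no extra copy).
        typed = buffer[byte_offset:byte_offset + nbytes].view(dtype)
        restored.append(typed.as_strided(shape, stride))
    return restored
```

Carrying the dtype through the metadata tuple is exactly what the review comments below flag: without it, every restored subview comes back as uint8.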
    if col_scale_info is not None:
        byte_offset, shape, stride, _ = col_scale_info
        tensor._columnwise_scale_inv = columnwise_buffer.as_strided(shape, stride, byte_offset)
The dtype is stored in col_scale_info but not used. as_strided on a uint8 buffer returns a uint8 tensor, while _columnwise_scale_inv should be float32 (for Float8Blockwise) or uint8 (for MXFP8). This will cause type mismatches when the restored tensor is used.
Suggested change:
     if col_scale_info is not None:
-        byte_offset, shape, stride, _ = col_scale_info
-        tensor._columnwise_scale_inv = columnwise_buffer.as_strided(shape, stride, byte_offset)
+        byte_offset, shape, stride, dtype = col_scale_info
+        tensor._columnwise_scale_inv = columnwise_buffer.as_strided(shape, stride, byte_offset).view(dtype)
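The reviewer's point can be reproduced in isolation: `as_strided` on a uint8 buffer always yields a uint8 tensor, so the stored dtype must be reapplied via `view(dtype)`. A minimal standalone sketch (illustrative names, not TransformerEngine code):

```python
import torch

# A float32 scale tensor offloaded into a raw uint8 byte buffer.
scale_inv = torch.tensor([[0.5, 2.0], [4.0, 0.25]], dtype=torch.float32)
buffer = scale_inv.reshape(-1).view(torch.uint8).clone()  # 16 raw bytes

shape, stride = (2, 2), (2, 1)
dtype = torch.float32  # what the metadata tuple should carry

# Buggy restore: as_strided on the uint8 buffer stays uint8 -- the
# selected elements are raw bytes, not the original float values.
bad = buffer.as_strided(shape, stride, 0)
assert bad.dtype == torch.uint8  # the type mismatch the review flags

# Fixed restore: reinterpret the bytes with the stored dtype first,
# then carve out the subview with the recorded shape/stride.
good = buffer.view(dtype).as_strided(shape, stride, 0)
assert good.dtype == torch.float32
assert torch.equal(good, scale_inv)
```

Note that in this sketch `view(dtype)` is applied before `as_strided`, so the shape/stride are interpreted in float32 element units; ordering and units are details the real fix has to get right as well.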
    if col_amax_info is not None:
        byte_offset, shape, stride, _ = col_amax_info
        tensor._columnwise_amax = columnwise_buffer.as_strided(shape, stride, byte_offset)
Same dtype issue: _columnwise_amax should be float32 for NVFP4, but it will be restored as uint8.
Suggested change:
     if col_amax_info is not None:
-        byte_offset, shape, stride, _ = col_amax_info
-        tensor._columnwise_amax = columnwise_buffer.as_strided(shape, stride, byte_offset)
+        byte_offset, shape, stride, dtype = col_amax_info
+        tensor._columnwise_amax = columnwise_buffer.as_strided(shape, stride, byte_offset).view(dtype)
    cuda_graph_stream.wait_stream(torch.cuda.current_stream())
    with cuda_graph_stream:
        fwd_graph.replay()
    torch.cuda.current_stream().wait_event(cuda_graph_event)
Missing cuda_graph_event.record(cuda_graph_stream) after the replay. Without recording the event, wait_event waits on the wrong completion point.
Suggested change:
     cuda_graph_stream.wait_stream(torch.cuda.current_stream())
     with cuda_graph_stream:
         fwd_graph.replay()
+    cuda_graph_event.record(cuda_graph_stream)
     torch.cuda.current_stream().wait_event(cuda_graph_event)
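The ordering bug can be illustrated with a toy model of stream/event semantics (purely illustrative Python classes, not the CUDA API): waiting on an event that was never recorded imposes no ordering at all, which matches CUDA's behavior where waiting on an unrecorded event is a no-op, so the current stream can race ahead of the graph replay.

```python
# Toy model: a stream is a list of launched work; an event snapshots a
# stream's position when recorded; waiting on the event orders the waiter
# after everything launched before that snapshot.

class Event:
    def __init__(self):
        self.position = None  # None: never recorded

    def record(self, stream):
        self.position = len(stream.work)

class Stream:
    def __init__(self):
        self.work = []

    def launch(self, name):
        self.work.append(name)

def wait_event(event, stream):
    """Return the work the waiter is actually ordered after."""
    if event.position is None:
        return []  # un-recorded event: no ordering guarantee at all
    return stream.work[:event.position]

graph_stream, event = Stream(), Event()
graph_stream.launch("fwd_graph.replay")

# Buggy pattern: wait_event without a prior record() -> ordered after nothing.
assert wait_event(event, graph_stream) == []

# Fixed pattern: record after the replay, then the wait covers the replay.
event.record(graph_stream)
assert wait_event(event, graph_stream) == ["fwd_graph.replay"]
```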
    ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
    with ctx.cuda_graph_stream:
        bwd_graph.replay()
    torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)
Same issue: missing ctx.cuda_graph_event.record(ctx.cuda_graph_stream) after the backward graph replay.
Suggested change:
     ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
     with ctx.cuda_graph_stream:
         bwd_graph.replay()
+    ctx.cuda_graph_event.record(ctx.cuda_graph_stream)
     torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: