
Hongbinl/offload activation cuda graph mxfp8 offload fix #2716

Draft
lhb8125 wants to merge 25 commits into NVIDIA:main from lhb8125:hongbinl/offload_activation_cuda_graph_mxfp8_offload_fix

Conversation


lhb8125 commented Feb 27, 2026

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

lhb8125 and others added 24 commits November 30, 2025 21:33
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>
Signed-off-by: root <root@eos0046.eos.clusters.nvidia.com>
@lhb8125 lhb8125 marked this pull request as draft February 27, 2026 11:09

greptile-apps bot commented Feb 27, 2026

Greptile Summary

This PR enhances CPU offload for CUDA graphs with MXFP8 quantization by switching from individual tensor offload to buffer-based offload with subview restoration.

Major changes:

  • Modified split_quantize to return a (tensor_list, buffer_list) tuple instead of just a tensor list
  • Added get_columnwise_subview_info() and restore_columnwise_subviews() for tracking/restoring tensor views after offload
  • Changed grouped_linear to offload buffers instead of individual tensors when CPU offloading is enabled
  • Added CUDA stream synchronization for graph replay with cuda_graph_stream and cuda_graph_event parameters
  • Commented out assertion in float8_blockwise_tensor_storage.py to allow both usage flags to be False during buffer-based offload
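
The buffer-based offload idea above can be sketched with a NumPy analogue. The helper names `get_subview_info`/`restore_subview` below are hypothetical stand-ins for the PR's `get_columnwise_subview_info()`/`restore_columnwise_subviews()`, which operate on torch tensors: capture each subview's (byte_offset, shape, strides, dtype) against the flat backing buffer, offload only the buffer, then rebuild the views on restore.

```python
import numpy as np

# Hypothetical sketch of the subview capture/restore idea using NumPy
# (the real PR operates on torch tensors inside Transformer Engine).
# A quantized tensor's fields (data, scale_inv, amax) live as strided
# views into one flat uint8 buffer; to offload, we save only the buffer
# plus per-view metadata, then rebuild the views after reload.

def get_subview_info(view, buffer):
    """Record (byte_offset, shape, strides, dtype) of a view into `buffer`."""
    byte_offset = view.__array_interface__["data"][0] - buffer.__array_interface__["data"][0]
    return byte_offset, view.shape, view.strides, view.dtype

def restore_subview(buffer, info):
    """Rebuild a view from the flat uint8 buffer using the saved metadata."""
    byte_offset, shape, strides, dtype = info
    return np.lib.stride_tricks.as_strided(
        buffer[byte_offset:].view(dtype), shape=shape, strides=strides)

buffer = np.arange(32, dtype=np.uint8)   # flat backing storage
scale = buffer[16:32].view(np.float32)   # a float32 subview into the buffer
info = get_subview_info(scale, buffer)   # metadata survives the offload
restored = restore_subview(buffer, info)
assert restored.dtype == np.float32
assert np.array_equal(restored, scale)
```

Note that the dtype must travel with the metadata, since a strided view of a uint8 buffer is otherwise uint8; this is exactly the bug flagged below.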

Critical issues found:

  • restore_columnwise_subviews() captures dtype info but does not use it, so restored tensors will have the wrong dtype (uint8 instead of float32 for scales/amax)
  • Missing event.record() calls after graph replay in forward and backward, so the stream synchronization will not order work correctly
  • API breaking change: split_quantize now returns a tuple, but transformer_engine/pytorch/ops/basic/grouped_linear.py (lines 527, 612) was not updated and will fail at runtime
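
The API break can be illustrated with a minimal sketch; `split_quantize` is stubbed here (the real function lives in Transformer Engine's C++ extensions, and its signature is assumed):

```python
# Hypothetical illustration of the breaking change: split_quantize now
# returns (tensor_list, buffer_list), so callers that expect a bare list
# must be updated to unpack the tuple.

def split_quantize(inputs, quantizers):
    # Stub standing in for the real extension function.
    tensor_list = [f"qtensor_{i}" for i, _ in enumerate(inputs)]
    buffer_list = [f"buffer_{i}" for i, _ in enumerate(inputs)]
    return tensor_list, buffer_list

inputs, quantizers = [0, 1], [None, None]

# Old call site (as in ops/basic/grouped_linear.py) would now be wrong:
# tensors = split_quantize(inputs, quantizers)   # tensors is a 2-tuple!

# Updated call site: unpack the tuple explicitly.
tensors, buffers = split_quantize(inputs, quantizers)
assert len(tensors) == 2 and len(buffers) == 2
```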

Confidence Score: 1/5

  • This PR contains multiple critical bugs that will cause runtime failures
  • Three critical logic bugs: (1) dtype restoration bug causes wrong tensor types after offload, (2) missing CUDA event recording breaks graph synchronization, (3) API change breaks existing code in ops/basic/grouped_linear.py which wasn't updated
  • Critical: transformer_engine/pytorch/quantized_tensor.py (dtype bug), transformer_engine/pytorch/graph.py (sync bug), and transformer_engine/pytorch/ops/basic/grouped_linear.py (not in PR but will be broken)

Important Files Changed

  • transformer_engine/pytorch/quantized_tensor.py: Adds subview tracking/restoration functions for CPU offload, but the captured dtype information is not used during restoration
  • transformer_engine/pytorch/module/grouped_linear.py: Changes the CPU offload strategy to offload buffers instead of individual tensors; adds subview restoration in the backward pass
  • transformer_engine/pytorch/ops/basic/grouped_linear.py: Not changed in this PR, but has a critical bug: its split_quantize calls were not updated to unpack the tuple return value (lines 527, 612)
  • transformer_engine/pytorch/graph.py: Adds pre/post warmup hooks and CUDA stream synchronization for graph replay, but is missing event.record() after graph replay
  • transformer_engine/pytorch/csrc/extensions/cast.cpp: Updates bulk allocation functions to return a buffer list alongside the tensor list for CPU offload support
  • transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py: Comments out the assertion requiring at least one of columnwise/rowwise usage; needed for buffer-based offload where both flags are False

Last reviewed commit: 484b0d5

greptile-apps bot left a comment
11 files reviewed, 4 comments


Comment on lines +268 to +270

```python
if col_scale_info is not None:
    byte_offset, shape, stride, _ = col_scale_info
    tensor._columnwise_scale_inv = columnwise_buffer.as_strided(shape, stride, byte_offset)
```

The dtype is stored in col_scale_info but never used. as_strided on a uint8 buffer returns a uint8 tensor, but _columnwise_scale_inv should be float32 (for Float8Blockwise) or uint8 (for MXFP8). This will cause type mismatches when the restored tensor is used.

Suggested change:

```python
if col_scale_info is not None:
    byte_offset, shape, stride, dtype = col_scale_info
    tensor._columnwise_scale_inv = columnwise_buffer.as_strided(shape, stride, byte_offset).view(dtype)
```

Comment on lines +273 to +275
if col_amax_info is not None:
byte_offset, shape, stride, _ = col_amax_info
tensor._columnwise_amax = columnwise_buffer.as_strided(shape, stride, byte_offset)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same dtype issue: _columnwise_amax should be float32 for NVFP4, but will be restored as uint8

Suggested change
if col_amax_info is not None:
byte_offset, shape, stride, _ = col_amax_info
tensor._columnwise_amax = columnwise_buffer.as_strided(shape, stride, byte_offset)
if col_amax_info is not None:
byte_offset, shape, stride, dtype = col_amax_info
tensor._columnwise_amax = columnwise_buffer.as_strided(shape, stride, byte_offset).view(dtype)

Comment on lines +808 to +811
cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with cuda_graph_stream:
fwd_graph.replay()
torch.cuda.current_stream().wait_event(cuda_graph_event)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing cuda_graph_event.record(cuda_graph_stream) after replay. Without recording the event, wait_event waits for the wrong completion point

Suggested change
cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with cuda_graph_stream:
fwd_graph.replay()
torch.cuda.current_stream().wait_event(cuda_graph_event)
cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with cuda_graph_stream:
fwd_graph.replay()
cuda_graph_event.record(cuda_graph_stream)
torch.cuda.current_stream().wait_event(cuda_graph_event)
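The hazard can be illustrated with a CPU sketch using threading.Event. The semantics differ from CUDA events (waiting on an unrecorded CUDA event completes immediately rather than blocking), but in both cases the ordering guarantee only exists once the event is recorded after the work:

```python
import threading
import time

# CPU analogue of the CUDA event bug: waiting on an event that nobody
# records/sets provides no ordering guarantee. Event.wait with a timeout
# returns False when the event was never set, analogous to wait_event on
# a CUDA event that was never recorded after the graph replay.
results = []

def replay_without_record(event):
    time.sleep(0.05)               # simulate graph replay work
    results.append("replayed")
    # BUG analogue: event.set() is missing here

event = threading.Event()
t = threading.Thread(target=replay_without_record, args=(event,))
t.start()
ok = event.wait(timeout=0.2)       # "wait_event": never signalled
t.join()
assert ok is False                 # the wait did not observe completion

# Fixed version: record (set) the event after the work finishes.
event2 = threading.Event()

def replay_with_record(event):
    results.append("replayed2")
    event.set()                    # analogue of cuda_graph_event.record()

t2 = threading.Thread(target=replay_with_record, args=(event2,))
t2.start()
assert event2.wait(timeout=1.0) is True
t2.join()
```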

Comment on lines +828 to +831
ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with ctx.cuda_graph_stream:
bwd_graph.replay()
torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue: missing ctx.cuda_graph_event.record(ctx.cuda_graph_stream) after backward graph replay

Suggested change
ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with ctx.cuda_graph_stream:
bwd_graph.replay()
torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)
ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with ctx.cuda_graph_stream:
bwd_graph.replay()
ctx.cuda_graph_event.record(ctx.cuda_graph_stream)
torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)

