
[PyTorch] Remove is_first_microbatch setting after cudagraph warmup #2715

Open
buptzyb wants to merge 1 commit into NVIDIA:main from buptzyb:robinz/graph_isfirstmicrobatch

Conversation

@buptzyb
Contributor

@buptzyb buptzyb commented Feb 27, 2026

Description

Resetting is_first_microbatch to True after warmup is unnecessary, because it is later set back to False anyway: https://github.com/NVIDIA/TransformerEngine/blob/release_v2.12/transformer_engine/pytorch/module/layernorm_mlp.py#L2037-L2042

is_first_microbatch has two functions. The first is to control fp8 cast-transpose behavior, which is superseded by the skip_fp8_weight_update tensor during graph replay. The second is to control whether the fused wgrad should clear the main_grad buffer before accumulating into it. We should always capture the no-clear version, in case the user captures one graph for all microbatches.
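The second role above can be sketched as follows. This is an illustrative toy helper, not the actual TransformerEngine fused wgrad kernel: with is_first_microbatch=True the persistent main_grad buffer is overwritten (the "clear" version), otherwise the new gradient is added on top (the "no-clear" version). A CUDA graph replayed for every microbatch must bake in the no-clear path, since the Python bool is frozen at capture time:

```python
import torch

def fused_wgrad_accumulate(main_grad: torch.Tensor,
                           wgrad: torch.Tensor,
                           is_first_microbatch: bool) -> None:
    """Toy model of fused wgrad accumulation into a persistent buffer."""
    if is_first_microbatch:
        main_grad.copy_(wgrad)   # "clear" version: overwrite the buffer
    else:
        main_grad.add_(wgrad)    # "no-clear" version: accumulate on top

# Three microbatches accumulating into one main_grad buffer.
main_grad = torch.zeros(4)
for step in range(3):
    fused_wgrad_accumulate(main_grad, torch.ones(4),
                           is_first_microbatch=(step == 0))
print(main_grad)  # tensor([3., 3., 3., 3.])
```

If the graph had instead been captured with is_first_microbatch=True, every replay would re-run the clear branch and the buffer would never accumulate past a single microbatch's gradient.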

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Robin Zhang <robinz@nvidia.com>
@buptzyb buptzyb force-pushed the robinz/graph_isfirstmicrobatch branch from 886f8cb to 25af05a on February 27, 2026 at 08:27
@greptile-apps
Contributor

greptile-apps bot commented Feb 27, 2026

Greptile Summary

Removes unnecessary code that was resetting is_first_microbatch to True after CUDA graph warmup iterations.

The removed code was originally added for Megatron-Core requirements to prevent warmup from altering control flow. However, this reset is actually unnecessary and potentially harmful:

  • Why unnecessary: During graph replay, skip_fp8_weight_update tensor controls FP8 weight updates, and modules automatically set is_first_microbatch=False when this tensor is present (e.g., layernorm_mlp.py:2053-2054)
  • Why harmful: Resetting to True forces capturing the "clear" version of gradient accumulation (for first microbatch only), but for graphs reused across all microbatches, the "no-clear" version is needed to properly accumulate gradients

The two functions of is_first_microbatch are:

  1. Control FP8 cast-transpose behavior → now handled by skip_fp8_weight_update during graph replay
  2. Control fused wgrad clearing of main_grad buffer → should capture no-clear version for multi-microbatch usage

This fix ensures correct gradient accumulation when one CUDA graph is captured and replayed for multiple microbatches.
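The flag-tensor mechanism the summary relies on can be sketched as below. This is a simplified illustration, not TransformerEngine's implementation: because CUDA graph replay re-executes recorded kernels on fixed buffers, behavior can only change between replays by mutating a tensor the captured region reads (here a hypothetical skip_fp8_weight_update flag), never by flipping a Python bool, which is frozen at capture time:

```python
import torch

# Device-side flag read inside the captured region:
# 0 -> perform the FP8 weight update, 1 -> skip it and keep the cached copy.
skip_fp8_weight_update = torch.zeros(1)

def region_that_would_be_captured(w_fp8: torch.Tensor, w: torch.Tensor) -> None:
    # A data-dependent select replaces a Python `if`, so the same recorded
    # kernels serve both cases on every replay.
    updated = torch.where(skip_fp8_weight_update.bool(), w_fp8, w)
    w_fp8.copy_(updated)

w = torch.full((2,), 3.0)
w_fp8 = torch.zeros(2)

region_that_would_be_captured(w_fp8, w)   # flag = 0: cached copy is refreshed
skip_fp8_weight_update.fill_(1.0)
w.fill_(5.0)
region_that_would_be_captured(w_fp8, w)   # flag = 1: stale cache kept on purpose
print(w_fp8)  # tensor([3., 3.])
```

With is_first_microbatch effectively replaced by this tensor during replay, the Python-level reset after warmup served no purpose for FP8 handling and only forced the wrong wgrad branch into the capture.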

Confidence Score: 5/5

  • This PR is safe to merge with no identified risks
  • The change removes unnecessary code that was actually causing incorrect behavior. The PR author provides clear technical justification, the logic aligns with how skip_fp8_weight_update controls behavior during graph replay, and the removed code was interfering with proper gradient accumulation in multi-microbatch scenarios. The change is small (5 lines), well-understood, and fixes a bug rather than introducing new functionality.
  • No files require special attention

Important Files Changed

Filename | Overview
transformer_engine/pytorch/graph.py | Removed the unnecessary reset of is_first_microbatch after warmup, fixing cudagraph capture for multi-microbatch scenarios

Last reviewed commit: 25af05a


@greptile-apps greptile-apps bot left a comment


1 file reviewed, no comments

