Skip to content

[JAX] Deprecate GSPMD: remove infer_sharding_from_operands and GSPMD tests#2702

Open
phu0ngng wants to merge 11 commits intoNVIDIA:mainfrom
phu0ngng:rm_gspmd
Open

[JAX] Deprecate GSPMD: remove infer_sharding_from_operands and GSPMD tests#2702
phu0ngng wants to merge 11 commits intoNVIDIA:mainfrom
phu0ngng:rm_gspmd

Conversation

@phu0ngng
Copy link
Collaborator

Description

GSPMD sharding propagation is being deprecated in favour of Shardy, which is now the default JAX partitioner. This commit removes all GSPMD-related code paths and tests:

  • Drop the infer_sharding_from_operands abstract method from BasePrimitive and remove it from def_partition() registration
  • Remove all infer_sharding_from_operands implementations across cpp_extensions: activation, amax, attention, gemm, normalization, quantization, and softmax primitives
  • Remove stale "Keep in sync with infer_sharding_from_operands" comments from FusedAttn shardy_sharding_rule methods
  • Drop all use_shardy=False (GSPMD) distributed test paths and the jax.config.update("jax_use_shardy_partitioner", ...) config calls
  • Consolidate paired GSPMD/Shardy test functions into single tests and strip _shardy suffixes from test names

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Feb 24, 2026

Greptile Summary

This PR removes GSPMD sharding propagation in favor of Shardy as the default JAX partitioner. The changes are comprehensive and well-structured:

Core Changes:

  • Removes the infer_sharding_from_operands abstract method from BasePrimitive in base.py
  • Removes all infer_sharding_from_operands implementations across cpp_extensions (activation, amax, attention, gemm, normalization, quantization, softmax)
  • Removes stale "Keep in sync with infer_sharding_from_operands" comments from FusedAttn shardy_sharding_rule methods
  • Consolidates paired GSPMD/Shardy test functions by removing use_shardy parameters and _shardy suffixed test variants

Behavioral Changes:

  • In gemm.py, replaces NotImplementedError with a UserWarning for CollectiveGEMM with Shardy, allowing execution but warning about potentially incorrect sharding patterns
  • Adds similar warnings in dense.py and layernorm_mlp.py for CollectiveGEMM usage without explicit output sharding constraints
  • Updates test_gemm.py to provide explicit output_sharding instead of relying on automatic inference

Test Coverage:

  • All distributed tests updated to use Shardy exclusively by removing jax.config.update("jax_use_shardy_partitioner", ...) calls
  • Example scripts remove --enable-shardy CLI arguments and associated test variants

Issues Found:

  • The triton_extensions/permutation.py file (not in this PR) still contains infer_sharding_from_operands methods that will no longer be registered after this change

Confidence Score: 4/5

  • This PR is generally safe to merge with one important consideration regarding triton extensions
  • Score reflects thorough removal of GSPMD code paths and comprehensive test updates. Reduced from 5 due to triton_extensions/permutation.py primitives still defining now-unused infer_sharding_from_operands methods that should be cleaned up for consistency
  • transformer_engine/jax/triton_extensions/permutation.py (not in this PR) should be updated in a follow-up to remove unused infer_sharding_from_operands methods

Important Files Changed

Filename Overview
transformer_engine/jax/cpp_extensions/base.py Cleanly removes infer_sharding_from_operands abstract method and its registration in def_partition() call
transformer_engine/jax/cpp_extensions/gemm.py Removes infer_sharding_from_operands method and replaces NotImplementedError with warning for CollectiveGEMM with Shardy
transformer_engine/jax/cpp_extensions/attention.py Removes infer_sharding_from_operands from FusedAttnFwd/BwdPrimitive and cleanup stale sync comments
transformer_engine/jax/dense.py Adds warning for CollectiveGEMM with Shardy when output_axes not set to prevent incorrect sharding patterns
transformer_engine/jax/layernorm_mlp.py Adds warning for CollectiveGEMM with Shardy when dot_2_input_axes not set to ensure correct output sharding
examples/jax/collective_gemm/test_gemm.py Removes Shardy disable config and changes from output_sharding=None to explicit output_sharding in test

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Primitive Registration] --> B[register_primitive called]
    B --> C[def_partition setup]
    
    C --> D{Before This PR}
    C --> E{After This PR}
    
    D --> F[GSPMD Path Available]
    D --> G[Shardy Path Available]
    
    F --> H[infer_sharding_from_operands<br/>registered and used]
    G --> I[shardy_sharding_rule<br/>registered and used]
    
    H --> J{JAX Config:<br/>use_shardy_partitioner}
    I --> J
    
    J -->|False| K[Use GSPMD:<br/>infer_sharding_from_operands]
    J -->|True| L[Use Shardy:<br/>shardy_sharding_rule]
    
    E --> M[Only Shardy Path]
    M --> N[shardy_sharding_rule<br/>registered and used]
    M --> O[infer_sharding_from_operands<br/>removed, not registered]
    
    N --> P[Always use Shardy<br/>No config needed]
    
    style H fill:#ffcccc
    style K fill:#ffcccc
    style O fill:#ffcccc
    style P fill:#ccffcc
    style N fill:#ccffcc
Loading

Last reviewed commit: 81a2bfa

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

21 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

@phu0ngng
Copy link
Collaborator Author

/te-ci JAX L1

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

12 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

17 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

19 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

phu0ngng and others added 11 commits February 26, 2026 16:29
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…tests

GSPMD sharding propagation is being deprecated in favour of Shardy,
which is now the default JAX partitioner. This commit removes all
GSPMD-related code paths and tests:

- Drop the infer_sharding_from_operands abstract method from
  BasePrimitive and remove it from def_partition() registration
- Remove all infer_sharding_from_operands implementations across
  cpp_extensions: activation, amax, attention, gemm, normalization,
  quantization, and softmax primitives
- Remove stale "Keep in sync with infer_sharding_from_operands"
  comments from FusedAttn shardy_sharding_rule methods
- Drop all use_shardy=False (GSPMD) distributed test paths and the
  jax.config.update("jax_use_shardy_partitioner", ...) config calls
- Consolidate paired GSPMD/Shardy test functions into single tests
  and strip _shardy suffixes from test names

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Comment on lines 201 to 206
batching.primitive_batchers[outer_p] = cls.batcher
outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
outer_p_lower.def_partition(
infer_sharding_from_operands=cls.infer_sharding_from_operands,
partition=cls.partition,
sharding_rule=cls.shardy_sharding_rule,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing infer_sharding_from_operands from def_partition() will affect all primitives that inherit from BasePrimitive. The file transformer_engine/jax/triton_extensions/permutation.py contains multiple primitives (RowIdMapPass1Primitive, RowIdMapPass2Primitive, etc.) that still define infer_sharding_from_operands methods but are not updated in this PR. After this change, those methods will no longer be registered or used, potentially causing different sharding behavior for triton extension primitives.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant