[JAX] Deprecate GSPMD: remove infer_sharding_from_operands and GSPMD tests#2702
phu0ngng wants to merge 11 commits into NVIDIA:main
Greptile Summary
This PR removes GSPMD sharding propagation in favor of Shardy as the default JAX partitioner. The changes are comprehensive and well-structured:
Core Changes:
Behavioral Changes:
Test Coverage:
Issues Found:
Confidence Score: 4/5
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[Primitive Registration] --> B[register_primitive called]
B --> C[def_partition setup]
C --> D{Before This PR}
C --> E{After This PR}
D --> F[GSPMD Path Available]
D --> G[Shardy Path Available]
F --> H[infer_sharding_from_operands<br/>registered and used]
G --> I[shardy_sharding_rule<br/>registered and used]
H --> J{JAX Config:<br/>use_shardy_partitioner}
I --> J
J -->|False| K[Use GSPMD:<br/>infer_sharding_from_operands]
J -->|True| L[Use Shardy:<br/>shardy_sharding_rule]
E --> M[Only Shardy Path]
M --> N[shardy_sharding_rule<br/>registered and used]
M --> O[infer_sharding_from_operands<br/>removed, not registered]
N --> P[Always use Shardy<br/>No config needed]
style H fill:#ffcccc
style K fill:#ffcccc
style O fill:#ffcccc
style P fill:#ccffcc
style N fill:#ccffcc
Last reviewed commit: 81a2bfa
/te-ci JAX L1
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
…tests
GSPMD sharding propagation is being deprecated in favour of Shardy,
which is now the default JAX partitioner. This commit removes all
GSPMD-related code paths and tests:
- Drop the infer_sharding_from_operands abstract method from
BasePrimitive and remove it from def_partition() registration
- Remove all infer_sharding_from_operands implementations across
cpp_extensions: activation, amax, attention, gemm, normalization,
quantization, and softmax primitives
- Remove stale "Keep in sync with infer_sharding_from_operands"
comments from FusedAttn shardy_sharding_rule methods
- Drop all use_shardy=False (GSPMD) distributed test paths and the
jax.config.update("jax_use_shardy_partitioner", ...) config calls
- Consolidate paired GSPMD/Shardy test functions into single tests
and strip _shardy suffixes from test names
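One way to check that a cleanup like the one listed above is complete is to scan test sources for leftover partitioner toggles. The sketch below is not part of this PR: `find_partitioner_toggles` is a hypothetical helper, and the `sample` string is made-up input. It uses Python's standard `ast` module to find calls of the form `<something>.update("jax_use_shardy_partitioner", ...)`.

```python
import ast

# Config flag named in the commit message above.
FLAG = "jax_use_shardy_partitioner"


def find_partitioner_toggles(source: str) -> list[int]:
    """Return line numbers of `.update("jax_use_shardy_partitioner", ...)` calls.

    Hypothetical helper: it matches any attribute call named `update` whose
    first positional argument is the flag string, so it does not verify that
    the receiver is really `jax.config`.
    """
    hits = []
    for node in ast.walk(ast.parse(source)):
        if not isinstance(node, ast.Call):
            continue
        func = node.func
        is_update = isinstance(func, ast.Attribute) and func.attr == "update"
        first = node.args[0] if node.args else None
        names_flag = isinstance(first, ast.Constant) and first.value == FLAG
        if is_update and names_flag:
            hits.append(node.lineno)
    return sorted(hits)


# Made-up test source: one partitioner toggle, one unrelated config call.
sample = '''import jax
jax.config.update("jax_use_shardy_partitioner", False)
def test_layernorm():
    jax.config.update("jax_enable_x64", True)
'''

print(find_partitioner_toggles(sample))
```

Running the helper over each file in a test directory would list any toggles that survived the consolidation; an empty result for every file suggests the GSPMD test paths are gone.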
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
for more information, see https://pre-commit.ci
batching.primitive_batchers[outer_p] = cls.batcher
outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
outer_p_lower.def_partition(
    infer_sharding_from_operands=cls.infer_sharding_from_operands,
    partition=cls.partition,
    sharding_rule=cls.shardy_sharding_rule,
)
Removing infer_sharding_from_operands from def_partition() will affect all primitives that inherit from BasePrimitive. The file transformer_engine/jax/triton_extensions/permutation.py contains multiple primitives (RowIdMapPass1Primitive, RowIdMapPass2Primitive, etc.) that still define infer_sharding_from_operands methods but are not updated in this PR. After this change, those methods will no longer be registered or used, potentially causing different sharding behavior for triton extension primitives.
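The concern above can be checked mechanically. The following is a minimal sketch, not code from the repository: `stale_overrides` is a hypothetical helper, and the three classes are stand-ins rather than the real `BasePrimitive` hierarchy. It walks a base class's subclasses and reports those that still define a given method in their own body, which is how leftover `infer_sharding_from_operands` definitions could be located.

```python
def stale_overrides(base, method_name):
    """Names of subclasses of `base` (at any depth) that define `method_name` themselves."""
    found = set()
    stack = list(base.__subclasses__())
    while stack:
        cls = stack.pop()
        # vars(cls) holds only attributes defined in this class body,
        # so inherited definitions are not counted.
        if method_name in vars(cls):
            found.add(cls.__name__)
        stack.extend(cls.__subclasses__())
    return sorted(found)


# Stand-in hierarchy for illustration only; these are not the real classes.
class BasePrimitive:
    pass


class UpdatedPrimitive(BasePrimitive):
    @staticmethod
    def shardy_sharding_rule(*args):
        return "i j -> i j"


class StalePrimitive(BasePrimitive):
    @staticmethod
    def infer_sharding_from_operands(*args):
        return None

    @staticmethod
    def shardy_sharding_rule(*args):
        return "i j -> i j"


print(stale_overrides(BasePrimitive, "infer_sharding_from_operands"))
```

Applied to the real base class after importing the extension modules, a non-empty result would point at primitives whose GSPMD method is now dead code. Whether the triton extension primitives inherit from `BasePrimitive` is taken from the comment above and is not verified here.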
Description
GSPMD sharding propagation is being deprecated in favour of Shardy, which is now the default JAX partitioner. This commit removes all GSPMD-related code paths and tests:
Type of change
Checklist: