[DRAFT] SPMD PP gather weights and write custom vjp#3071
Draft
[DRAFT] SPMD PP gather weights and write custom vjp#3071
Conversation
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
6c22238 to
28f98ff
Compare
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
28f98ff to
64b37ff
Compare
a7d38d0 to
9a36099
Compare
d05c015 to
e521a58
Compare
simple fix on debug sharding log add all gather insertion per repeat working all gather insertion clean version fsdp+pp bug free add bsw checkpoint split bsw all gather into two add custom vjp
e521a58 to
94812d9
Compare
51e6713 to
286e066
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
This PR refactors the Pipeline Parallelism (PP) core logic to improve efficiency and memory management when using circular pipelining. The key highlights include the introduction of a Buffer Sliding Window (BSW) for weights, the implementation of a custom VJP for scanned pipeline iterations, and support for scanning over pipeline repeats.
Key Changes
1. Configuration & Types
scan_pipeline_repeatsto allowjax.lax.scanover pipeline repeats.deepseek_batchsplitconfiguration to include thestageaxis inmesh_axesandlogical_axis_rules.2. Pipeline Core (
pipeline.py)shard_map.nn.scanwith a custom-defined VJP for pipeline iterations. This allows for manual gradient checkpointing where the forward pass is re-run during the backward pass to save memory on heavy states.loop_statenow carriesbswandweightsthrough iterations.3. Layers & Utils
pipeline_utils.py: A new utility module containing helper functions for FSDP axis indexing, logical spec manipulation, and thecreate_scanned_functionfactory for the custom VJP.deepseek_batchsplit.pyto ensuremegablox.gmmis used when pipeline parallelism is enabled.attention_op.pyandmoe.pyto skip logical rules when using PP, relying instead on the pipeline's sharding logic.4. Testing
pipeline_parallelism_test.pyto reflect changes in supported configurations.Implementation Details
The custom VJP (
run_scanned_custom_bwd) performs a backward scan to accumulate gradients. It reconstructs thecurr_loop_stateby combining saved lightweight states with the originalbswandweightsbefore applyingjax.vjpto the iteration step function.Tests
Please describe how you tested this change, and include any instructions and/or
commands to reproduce.
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-reviewlabel.