
[DRAFT] SPMD PP gather weights and write custom vjp#3071

Draft
NuojCheng wants to merge 5 commits into main from chengnuojin-pp-separate-weights

Conversation


@NuojCheng NuojCheng commented Feb 3, 2026

Description

This PR refactors the Pipeline Parallelism (PP) core logic to improve efficiency and memory management when using circular pipelining. The key highlights include the introduction of a Buffer Sliding Window (BSW) for weights, the implementation of a custom VJP for scanned pipeline iterations, and support for scanning over pipeline repeats.

Key Changes

1. Configuration & Types

  • New Config Options: Added scan_pipeline_repeats to allow jax.lax.scan over pipeline repeats.
  • Mesh Updates: Updated deepseek_batchsplit configuration to include the stage axis in mesh_axes and logical_axis_rules.
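
As an illustrative sketch only (the `scan_pipeline_repeats` key and the `stage` mesh axis come from this PR's description; every other key and value here is an assumption, not the repo's actual defaults), the new option might appear in a MaxText-style YAML config as:

```yaml
# Illustrative fragment; only scan_pipeline_repeats and the 'stage'
# axis are taken from the PR description, the rest is hypothetical.
scan_pipeline_repeats: True              # jax.lax.scan over pipeline repeats
mesh_axes: ["stage", "fsdp", "tensor"]   # 'stage' axis included for PP
```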

2. Pipeline Core (pipeline.py)

  • Buffer Sliding Window (BSW): Introduced BSW to manage weight gathering more efficiently. It maintains a buffer for weight copies that are all-gathered over the FSDP axis using shard_map.
  • Custom VJP Implementation: Replaced standard nn.scan with a custom-defined VJP for pipeline iterations. This allows for manual gradient checkpointing where the forward pass is re-run during the backward pass to save memory on heavy states.
  • Loop State Refactor: loop_state now carries bsw and weights through iterations.

3. Layers & Utils

  • pipeline_utils.py: A new utility module containing helper functions for FSDP axis indexing, logical spec manipulation, and the create_scanned_function factory for the custom VJP.
  • DeepSeek Support: Updated deepseek_batchsplit.py to ensure megablox.gmm is used when pipeline parallelism is enabled.
  • Layer Sharding: Modified attention_op.py and moe.py to skip logical rules when using PP, relying instead on the pipeline's sharding logic.

4. Testing

  • Updated pipeline_parallelism_test.py to reflect changes in supported configurations.
  • Added skips for non-circular pipelines and FP8 configurations, which are currently incompatible with the new BSW/custom VJP logic.

Implementation Details

The custom VJP (run_scanned_custom_bwd) performs a backward scan to accumulate gradients. It reconstructs the curr_loop_state by combining saved lightweight states with the original bsw and weights before applying jax.vjp to the iteration step function.
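
That pattern can be sketched as follows (a minimal, self-contained illustration with hypothetical names such as `checkpointed_step`, not `run_scanned_custom_bwd` itself): the forward pass saves only lightweight residuals, and the backward pass re-runs the step and applies `jax.vjp` to it to obtain gradients.

```python
import jax
import jax.numpy as jnp

def step(state, weights):
  # Stand-in for one pipeline iteration.
  return jnp.tanh(state @ weights)

@jax.custom_vjp
def checkpointed_step(state, weights):
  return step(state, weights)

def _fwd(state, weights):
  out = step(state, weights)
  # Save only the lightweight inputs as residuals, not heavy activations.
  return out, (state, weights)

def _bwd(residuals, g):
  state, weights = residuals
  # Re-run the forward pass during the backward pass, then pull the
  # cotangent g through it with jax.vjp.
  _, vjp_fn = jax.vjp(step, state, weights)
  return vjp_fn(g)  # gradients w.r.t. (state, weights)

checkpointed_step.defvjp(_fwd, _bwd)

s = jnp.ones((2, 3))
w = jnp.full((3, 3), 0.1)
g_custom = jax.grad(lambda w: checkpointed_step(s, w).sum())(w)
g_ref = jax.grad(lambda w: step(s, w).sum())(w)
```

The trade is extra compute in the backward pass for not storing the heavy intermediate state, which is the memory saving the description above targets; `g_custom` agrees with the gradient of the un-checkpointed `step`.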

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.


codecov bot commented Feb 3, 2026

Codecov Report

❌ Patch coverage is 88.88889% with 13 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
src/MaxText/layers/pipeline.py | 88.88% | 7 Missing and 6 partials ⚠️


@NuojCheng NuojCheng force-pushed the chengnuojin-pp-separate-weights branch 2 times, most recently from 6c22238 to 28f98ff on February 9, 2026 19:29

codecov bot commented Feb 9, 2026

Codecov Report

❌ Patch coverage is 86.56716% with 27 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
src/maxtext/utils/pipeline_utils.py | 85.18% | 8 Missing and 8 partials ⚠️
src/maxtext/layers/pipeline.py | 88.23% | 5 Missing and 5 partials ⚠️
src/maxtext/layers/decoders.py | 80.00% | 1 Missing ⚠️


@NuojCheng NuojCheng added the draft (Draft PR) label and removed the pull ready label Feb 9, 2026
@NuojCheng NuojCheng force-pushed the chengnuojin-pp-separate-weights branch from 28f98ff to 64b37ff on February 9, 2026 22:23
@NuojCheng NuojCheng force-pushed the chengnuojin-pp-separate-weights branch 8 times, most recently from a7d38d0 to 9a36099 on February 18, 2026 17:16
@NuojCheng NuojCheng force-pushed the chengnuojin-pp-separate-weights branch 2 times, most recently from d05c015 to e521a58 on February 24, 2026 23:14
gagika and others added 3 commits February 25, 2026 17:38
  • simple fix on debug sharding log
  • add all gather insertion per repeat
  • working all gather insertion
  • clean version fsdp+pp bug free
  • add bsw checkpoint
  • split bsw all gather into two
  • add custom vjp
@NuojCheng NuojCheng force-pushed the chengnuojin-pp-separate-weights branch from e521a58 to 94812d9 on February 25, 2026 17:38
@NuojCheng NuojCheng force-pushed the chengnuojin-pp-separate-weights branch from 51e6713 to 286e066 on February 26, 2026 00:16

Labels

draft Draft PR
