Skip to content

[OMNIML-5072] Triton fakequant adapter for N-quantizer per-expert path#1717

Draft
hychiang-git wants to merge 10 commits into
jennifchen/te_per_expertfrom
hungyuehc/omniml-5072
Draft

[OMNIML-5072] Triton fakequant adapter for N-quantizer per-expert path#1717
hychiang-git wants to merge 10 commits into
jennifchen/te_per_expertfrom
hungyuehc/omniml-5072

Conversation

@hychiang-git

Copy link
Copy Markdown
Contributor

What does this PR do?

Type of change: new feature

Implements OMNIML-5072, built on top of PR #1550 (the N-quantizer per-expert foundation, still WIP). Three additions:

  1. Triton fakequant dispatch for the N-quantizer per-expert path. Wires the single-launch grouped_axis0_fakequant kernel from PR #1671 (the One-Vec-quanitzer path) into _QuantTEGroupedLinear.te_grouped_quantized_linear_fn when _per_expert_weight_quantizer == True. Soft-gated behind _triton_kernels.IS_AVAILABLE and q._if_calib; falls back to the existing cuda_ext per-quantizer loop when the gate is False.
  2. Cached _gather_per_expert_amax helper eliminates per-forward O(N) Python overhead (walks the N submodules from PR WIP Support per expert amax in TEGroupedMLP #1550 and stacks the N scalar _amax buffers; lazily cached, invalidated from modelopt_post_restore).
  3. sharded_state_dict save + EP-aware load on _QuantMegatronTEGroupedLinear's N-quantizer case — gather-once-cache pattern adapted to N scalar _amax buffers across the EP group, so the dist-ckpt round-trip that PR [OMNIML-4998] Per-expert weight quantization on TEGroupedLinear #1671 ships for One-Vec-quanitzer also works on the N-quantizer foundation from PR WIP Support per expert amax in TEGroupedMLP #1550.

This PR is stacked on top of PR #1550 (jennifchen/te_per_expert, still WIP). The diff against main includes PR #1550's commits underneath; the OMNIML-5072-specific work is the top 6 commits (fd77b53d8..51b4c9226). Once PR #1550 lands, rebase onto main to shrink the review surface.

Usage

import os

# Enable the N-quantizer per-expert path on TEGroupedMLP. With this PR loaded,
# the Triton kernel dispatch activates automatically when triton is available.
os.environ["MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER"] = "1"

import modelopt.torch.quantization as mtq
# ... build a Megatron model with TEGroupedMLP ...
mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_fn)
# Each TEGroupedLinear now has N weight_quantizer_i submodules;
# forward uses Triton `grouped_axis0_fakequant` when available, otherwise
# falls back to cuda_ext per-quantizer.

Testing

GPU-validated on aws-cmh (B300, nemo:26.02 / nemo:25.11 containers):

  • Parity test (4 pytest cases, 6.14s): N-quantizer-Triton vs N-quantizer-cuda_ext on Ultra production shape (N=32, [5120, 8192] bf16) — fwd within 1-ULP floor, bwd bit-exact under pass_through_bwd=True. Test at tests/gpu/torch/quantization/plugins/test_te_grouped_triton_parity.py.
  • Microbench (4 cells: Nano / Super / Ultra at EP=4 and EP=8): N-quantizer-Triton ≈ One-Vec-quanitzer + Triton (PR [OMNIML-4998] Per-expert weight quantization on TEGroupedLinear #1671) within ~10%; the N-quantizer vs One-Vec-quanitzer topology has no measurable perf effect once both share the Triton kernel. Full matrix on OMNIML-5064.
  • Dist-ckpt round-trip (2 pytest cases, 92.99s) at TP=2/EP=2 for both FP8_DEFAULT_CFG and NVFP4_DEFAULT_CFG. Test at tests/gpu_megatron/torch/quantization/plugins/test_megatron.py::test_te_grouped_n_modules_sharded_state_dict.

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: ✅ — Triton path is soft-gated; when _triton_kernels.IS_AVAILABLE is False or during calibration (q._if_calib), the original cuda_ext per-quantizer loop runs unchanged.
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅ — Reused the grouped_axis0_fakequant kernel module from PR [OMNIML-4998] Per-expert weight quantization on TEGroupedLinear #1671 (51b4c9226 reuses it via the new _GroupedAxis0FakeQuantFn autograd adapter). No new PIP dependencies.
  • Did you write any new necessary tests?: ✅ — Parity test + N-quantizer sharded_state_dict TP=2/EP=2 test.
  • Did you update Changelog?: ❌ — Stacked on top of WIP PR WIP Support per expert amax in TEGroupedMLP #1550; changelog entry deferred until both are merge-ready (will land with the rebase).
  • Did you get Claude approval on this PR?: ❌ — Pending; will run /claude review after rebase onto main.

Additional Information

Related work:

  • OMNIML-5072 — this ticket.
  • OMNIML-5064 — N-quantizer vs One-Vec-quanitzer comparison study; full microbench matrix here.
  • PR #1550 — N-quantizer foundation (WIP); this PR is stacked on top.
  • PR #1671 — One-Vec-quanitzer + Triton kernel (the kernel this PR reuses).

🤖 Generated with Claude Code

hychiang-git and others added 6 commits June 13, 2026 14:12
Copies the single-launch tensor-of-pointers fake-quant kernel module from
hungyuehc/omniml-4998-umbrella (kernel landed in 0bf4838, libdevice.rint
rounding refined in 1080e68). Kernel file is unchanged.

Wires the new module into modelopt.torch.kernels.quantization.gemm via the
existing IS_AVAILABLE/triton-import block in __init__.py.

The transformer_engine.py adapter that calls this kernel from the N-modules
per-expert path follows in the next commit.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…for N-modules per-expert path

Wires grouped_axis0_fakequant into _QuantTEGroupedLinear's forward when
_per_expert_weight_quantizer is on. Adds:

- _GroupedAxis0FakeQuantFn (torch.autograd.Function): single forward call
  for N expert weights; backward honors pass_through_bwd=True (identity)
  and dispatches to the Triton bwd kernel when False.
- _gather_per_expert_amax: stacks N weight_quantizer_i._amax scalars into
  a [N] fp32 vector matching the kernel's amax-input contract.
- _can_use_triton_per_expert_path: soft-gate on IS_AVAILABLE, all per-expert
  quantizers being TensorQuantizer with _amax set, and not currently
  calibrating (q._if_calib).
- te_grouped_quantized_linear_fn now branches: Triton path when gate passes;
  original per-quantizer cuda_ext loop otherwise.

Replaces N cuda_ext kernel launches with 1 Triton launch on the forward
hot path. No behavior change when the env var opt-in is off.

Untested at runtime yet; AC2 parity test (Ultra production shape, N=32,
pass_through_bwd=True) is the next step.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds tests/gpu/torch/quantization/plugins/test_te_grouped_triton_parity.py
mirroring nmm-sandbox/studies/omniml-5064/microbench/parity_a_vs_btriton.py
as a pytest in the modelopt tests/ surface.

Two checks at each parameterized shape (N=4/8/32):
- Forward parity within 1 ULP. Known rounding-mode mismatch floor between
  Triton's libdevice.rint and cuda_ext's banker's rounding at some bf16
  boundary values.
- Backward parity bit-exact under pass_through_bwd=True. Both paths must
  return grad_out unchanged regardless of forward kernel.

Plus a slow-marked Ultra production shape variant (N=32, [5120, 8192] bf16)
for full-scale validation. Marked slow because the unquantized + quantized
+ gradient copies of 32 expert weights at that shape use ~5 GB of GPU
memory; CI default-suite stays on the smaller parameterized cases.

Test not yet run — requires GPU + container with modelopt installed.
Expected to PASS per the matching standalone parity_a_vs_btriton.py output
on B's path after libdevice.rint refinement (1080e68): forward parity
within 1 ULP, backward bit-exact under pass_through_bwd=True.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…les per-expert path

Adds dist-checkpoint support to A's N-submodule per-expert weight quantization
path on _QuantMegatronTEGroupedLinear. Mirrors B's gather-once-cache pattern
(hungyuehc/omniml-4998-umbrella) adapted for A's storage layout — N separate
weight_quantizer_i._amax scalars across N submodules instead of one [N,1,1]
buffer on a single quantizer.

Methods added:
- _ep_group: returns the EP group if initialized and world_size > 1.
- _gather_global_per_expert_amax_n_modules: stacks N local scalar amaxes from
  the submodules, all-gathers across EP, returns [N_global]. None when the
  layer is not in per-expert mode.
- sharded_state_dict: caches the gathered global vector before delegating to
  super so the EP collective completes BEFORE Megatron's dist-checkpoint save
  fires default-PG ALLGATHER metadata exchanges (interleaving EP + default-PG
  collectives deadlocks NCCL — codified in
  [[feedback-no-custom-collectives-in-dist-ckpt-save]]).

Methods replaced:
- _process_quantizer_amax: emits the cached global [N_global] vector under
  every weight_quantizer_i._amax key in per-expert mode. Suboptimal disk
  usage (N copies of same vector per layer) but mirrors B's pattern and
  avoids surgery into the base-class state-dict iteration.
- _load_from_state_dict: preserves the existing _extra_state{i} filter and
  adds the per-expert narrow — pulls element (ep_rank * N_local + i) out of
  each saved [N_global] vector for the i-th local submodule. Falls through
  unchanged when v.numel() != global_size (legacy / EP=1 save format).

Validated via the AC4 parity test (next commit) at TP=2, EP=2 across
FP8_DEFAULT_CFG and NVFP4_DEFAULT_CFG.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…dict round-trip

Adds test_te_grouped_n_modules_sharded_state_dict parameterized over
FP8_DEFAULT_CFG and NVFP4_DEFAULT_CFG. Builds a TEGroupedMLP model at
TP=2 EP=2 num_moe_experts=4, quantizes with the
MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER=1 env-var path, saves dist-ckpt,
restores into a fresh meta-device model, asserts equivalence via the
existing sharded_state_dict_test_helper.

Mirrors the layout of B's test_te_grouped_per_expert_sharded_state_dict
(hidden_size=256, dist_workers fixture) but triggers A's env-var-gated
per-expert path instead of B's axis=0 quant_cfg knob.

Also cherry-picks the OMNIML-5030 sequence_parallel fix that A's branch
predates: get_mcore_gpt_model in tests/_test_utils/torch/megatron/models.py
gains a sequence_parallel parameter that threads through TransformerConfig,
and the non-hybrid call site in _gpt_model_provider passes
sequence_parallel=(tp_size > 1). Without this, Megatron-Core ValueError-s
during MoE + TP > 1 model construction. The hybrid path on A's branch
already had this fix.

Validated on aws-cmh slurm — 2 passed, 92.99s wall (job 537620).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…d N-loop overhead

The Triton dispatch in te_grouped_quantized_linear_fn calls
_gather_per_expert_amax on every forward. The gather walks N submodules
(O(N) Python overhead) and stacks N scalar _amax buffers into a [N] fp32
vector. The result is invariant across forwards (per-expert _amax does
not change outside calibration, and the gate already blocks the Triton
path when q._if_calib is True on any quantizer).

Caching the gathered tensor lazily on first call eliminates the per-forward
overhead. Invalidation hook _invalidate_per_expert_amax_cache is called from
modelopt_post_restore (where dist-ckpt reload may have changed _amax).

Measured impact on OMNIML-5064 microbench (Nemotron Nano EP=4, N=32):

  fwd_us:   1918 (no-cache)  ->  1244 (cached)   (35% drop)
  step_us:  3444 (no-cache)  ->  2785 (cached)   (19% drop)
  vs Btriton5 (B's path with same Triton kernel):
            ATriton-cached 1244 vs Btriton5 1208  ->  effectively tied
            ATriton-cached step 2785 vs Btriton5 step 2815 -> ATriton edges B

Without the cache the gap to Btriton5 grows with N (1.59x at N=32, 2.18x
at N=128, observed in the no-cache nano_ep4 + super_ep4 runs). With the
cache, the gap closes to within run-to-run noise.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@copy-pr-bot

copy-pr-bot Bot commented Jun 14, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai

coderabbitai Bot commented Jun 14, 2026

Copy link
Copy Markdown
Contributor

Important

Review skipped

Draft detected.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 3579568a-9568-4404-99d5-f12c28c74eb9

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch hungyuehc/omniml-5072

Comment @coderabbitai help to get the list of available commands and usage tips.

hychiang-git and others added 3 commits June 13, 2026 21:23
The doc was a placeholder for B-triton's pre-validation tracking (commit
0bf4838 on PR #1671); validation has since landed and the file is no
longer load-bearing. The kernel module references it from a docstring
line that is stale and will be cleaned up by PR #1671's review pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…_quantizer_<N> siblings

The N-quantizers opt-in (MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER=1) adds
per-expert siblings as `weight_quantizer_<digits>` via add_module. The
canonical `*weight_quantizer` wildcard in stock QUANT_CFG entries does
not match these underscore-suffix names, so each sibling stayed at
default (`_disabled=True`). Forward called every sibling on every
calibration batch, but each call early-returned the inputs unchanged —
calibration was a no-op for routed-expert weights, `_amax` was never
populated, the save layer had nothing to write.

Extend `_normalize_fused_experts_quantizer_name` in conversion.py to
collapse the underscore-separator form `weight_quantizer_<N>` (and
`input_quantizer_<N>`) to the canonical singular name before fnmatch,
alongside the existing `weight_quantizer[s]?.N` dot-separator form
emitted by fused-experts plugins. Single-character widening: `\.` →
`[._]` in the separator class.

Stock QUANT_CFG patterns (`*weight_quantizer`, `*input_quantizer`) now
reach the N siblings, so they get enabled, configured, calibrated, and
saved like any other quantizer. Confirmed end-to-end on full
Nemotron-3-Nano PTQ smoke (NVFP4_DEFAULT_CFG, INT8_DEFAULT_CFG): saved
per-expert `_amax` keys went from 0 to 1472 (23 MoE layers x 2 linears
x 32 local experts).

Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
…er-expert path

The original per-expert save/load code in `_QuantMegatronTEGroupedLinear`
assumed each sibling's `_amax` was a scalar:
  - `_gather_global_per_expert_amax_n_modules` did `amax.view(())` then
    stacked into `[N_local]`, all-gather over EP -> `[N_global]`.
  - `_process_quantizer_amax` flattened the cached tensor via
    `cached.view(cached.numel())`.
  - `_load_from_state_dict` narrowed via
    `v.view(global_size)[offset+i].view(())` on a local filtered copy.

Holds for NVFP4_DEFAULT_CFG / FP8_DEFAULT_CFG (per-tensor weight `_amax`)
but breaks for any axis!=None config. INT8_DEFAULT_CFG (axis=0
per-channel) produces `_amax` shape `[out_per_expert, 1]` -- for
Nemotron-Nano fc1 that's `[1856, 1]`. The save side raised
`RuntimeError: shape '[]' is invalid for input of size 1856` at
`amax.view(())`.

Generalize all three methods to preserve native `_amax` shape:
  - Gather: stack along a new leading axis without flattening, produce
    `[N_local, *amax_shape]` locally and `[N_global, *amax_shape]` after
    EP all-gather.
  - Process: write the cached tensor verbatim under each
    `weight_quantizer_<i>._amax` key; fallback path drops the scalar-only
    assertion and emits the buffer in its native shape.
  - Load: detect `v.shape[0] == global_size` and slice dim 0, keeping
    trailing dims (`v[offset+local_i].contiguous()`). Fall back to the
    legacy flat-vector narrow when `v.numel() == global_size` but
    `v.dim() == 1`, for backward compat with previously-saved ckpts.

The load fix also has to mutate `state_dict` in place rather than passing
a filtered copy to `super()._load_from_state_dict`. PyTorch's
`Module.load_state_dict` recursion builds each child's filtered dict by
re-filtering the PARENT's `local_state_dict` AFTER the parent's
`_load_from_state_dict` returns; modifications to a local copy don't
propagate to the per-expert sibling children that actually own the
`_amax` buffers, and the strict-load size check rejects the un-narrowed
`[N_global, *amax_shape]` tensor as not matching the local
`[*amax_shape]` buffer. The `_extra_state{N}` suppression filter still
uses a local view because it only affects the parent's own strict-load
unexpected-key check.

Confirmed end-to-end on the full Nemotron-3-Nano PTQ + sharded ckpt
round-trip smoke:
  - NVFP4_DEFAULT_CFG (scalar `_amax`): 1472 per-expert keys round-trip,
    siblings carry distinct per-expert values.
  - INT8_DEFAULT_CFG (per-channel `[1856, 1]` `_amax`): 1472 per-expert
    keys round-trip with the `[1856, 1]` shape preserved on every
    sibling; per-expert max values distinct across siblings.

Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
@hychiang-git hychiang-git changed the base branch from main to jennifchen/te_per_expert June 15, 2026 16:37
@hychiang-git hychiang-git requested a review from jenchen13 June 15, 2026 16:39
…ensorrt=False)

The trtllm `fp4_quantize` CUDA op returns `weights_scaling_factor` in
cutlass-interleaved layout with `out_features` padded up to a multiple of 64
for GEMM tile alignment. For a Nemotron-Nano mamba `in_proj` weight of shape
`[10304, 2688]` this produces a `_scale` buffer of shape `[10368 * 168]`
flat — bytes 1,741,824 vs the canonical `[10304, 168]` shape = 1,731,072
elements. The shape mismatch breaks Megatron sharded-ckpt save/load:

  CheckpointingException: Global shape mismatch for loaded
  (torch.Size([1741824])) and expected ((10304, 168))
  for key decoder.layers.0.mixer.in_proj.weight_quantizer._scale

at the rq-export step (step 2's load of the compressed ckpt produced by
`mtq.compress` in step 1).

Pass `try_tensorrt=False` so `NVFP4QTensor.quantize` takes the canonical
Python path:

  per_block_amax = reduce_block_amax(input, block_sizes={-1: block_size})
  per_block_scale = per_block_amax / (6.0 * weights_scaling_factor_2)
  → shape == (out, in/block_size) ✓ matches load-side expectation

Trade-off: the cutlass NVFP4 GEMM kernel wants the interleaved layout for
inference. Re-interleaving happens at HF-export time via
`cutlass_fp4_scale_to_modelopt_fp4_scale` and its inverse (already wired
into the export-to-HF flow at unified_export_hf.py and the deserializer
in nvfp4_tensor.py line 279), so persisting the canonical form does not
block deployment — it just moves the layout transform from
save-time-only to export-time.

Smoke evidence on full Nemotron-3-Nano-30B (NVFP4_DEFAULT_CFG, EP=4):
  - Pre-fix: rq pipeline crashed at step 2 load with the global-shape
    mismatch above on `decoder.layers.0.mixer.in_proj.weight_quantizer._scale`.
  - Post-fix: step 2 load succeeds, model proceeds into HF export.
    A downstream packed-FP4 reshape bug in `unified_export_megatron.py::
    _qkv_slicing` (line 1068) is the next blocker but is unrelated to
    save/load layout — tracked separately.

Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant