[OMNIML-5072] Triton fakequant adapter for N-quantizer per-expert path#1717
Draft
hychiang-git wants to merge 10 commits into
Copies the single-launch tensor-of-pointers fake-quant kernel module from
hungyuehc/omniml-4998-umbrella (kernel landed in 0bf4838, libdevice.rint
rounding refined in 1080e68). Kernel file is unchanged.
Wires the new module into modelopt.torch.kernels.quantization.gemm via the
existing IS_AVAILABLE/triton-import block in __init__.py.
The transformer_engine.py adapter that calls this kernel from the N-modules
per-expert path follows in the next commit.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…for N-modules per-expert path
Wires grouped_axis0_fakequant into _QuantTEGroupedLinear's forward when
_per_expert_weight_quantizer is on. Adds:
- _GroupedAxis0FakeQuantFn (torch.autograd.Function): single forward call
for N expert weights; backward honors pass_through_bwd=True (identity) and
dispatches to the Triton bwd kernel when False.
- _gather_per_expert_amax: stacks N weight_quantizer_i._amax scalars into a
[N] fp32 vector matching the kernel's amax-input contract.
- _can_use_triton_per_expert_path: soft-gate on IS_AVAILABLE, all per-expert
quantizers being TensorQuantizer with _amax set, and not currently
calibrating (q._if_calib).
- te_grouped_quantized_linear_fn now branches: Triton path when gate passes;
original per-quantizer cuda_ext loop otherwise.
Replaces N cuda_ext kernel launches with 1 Triton launch on the forward hot
path. No behavior change when the env var opt-in is off.
Untested at runtime yet; AC2 parity test (Ultra production shape, N=32,
pass_through_bwd=True) is the next step.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
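The soft gate and the two-way branch described above can be sketched in plain Python. This is a minimal illustration, not the code in the PR: `StubQuantizer`, `can_use_triton_path`, and `fake_quant_dispatch` are hypothetical names, and the stub only models the `_amax` and `_if_calib` attributes the commit message mentions.

```python
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple


@dataclass
class StubQuantizer:
    """Hypothetical stand-in for one per-expert TensorQuantizer."""
    amax: Optional[float] = None  # plays the role of `_amax`
    if_calib: bool = False        # plays the role of `_if_calib`


def can_use_triton_path(triton_available: bool,
                        quantizers: Sequence[object]) -> bool:
    """Gate passes only if Triton is available and every quantizer is ready."""
    if not triton_available or len(quantizers) == 0:
        return False
    return all(
        isinstance(q, StubQuantizer) and q.amax is not None and not q.if_calib
        for q in quantizers
    )


def fake_quant_dispatch(triton_available: bool,
                        quantizers: Sequence[object]) -> Tuple[str, int]:
    """Return (path name, number of kernel launches) for one forward."""
    if can_use_triton_path(triton_available, quantizers):
        return ("triton", 1)               # one launch covers all N experts
    return ("cuda_ext", len(quantizers))   # fallback: one launch per expert
```

The point of the sketch is that the gate is evaluated per forward and any single failing condition (no Triton, missing amax, calibration in progress) sends the whole layer to the fallback loop.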
Adds tests/gpu/torch/quantization/plugins/test_te_grouped_triton_parity.py
mirroring nmm-sandbox/studies/omniml-5064/microbench/parity_a_vs_btriton.py
as a pytest in the modelopt tests/ surface.
Two checks at each parameterized shape (N=4/8/32):
- Forward parity within 1 ULP. Known rounding-mode mismatch floor between
Triton's libdevice.rint and cuda_ext's banker's rounding at some bf16
boundary values.
- Backward parity bit-exact under pass_through_bwd=True. Both paths must
return grad_out unchanged regardless of forward kernel.
Plus a slow-marked Ultra production shape variant (N=32, [5120, 8192] bf16)
for full-scale validation. Marked slow because the unquantized + quantized +
gradient copies of 32 expert weights at that shape use ~5 GB of GPU memory;
CI default-suite stays on the smaller parameterized cases.
Test not yet run; requires GPU + container with modelopt installed. Expected
to PASS per the matching standalone parity_a_vs_btriton.py output on B's path
after libdevice.rint refinement (1080e68): forward parity within 1 ULP,
backward bit-exact under pass_through_bwd=True.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
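To make "within 1 ULP" concrete for bf16, the following stdlib-only sketch measures the distance between two bf16 values in units in the last place. It is an illustration of the tolerance, not the test in the PR: the helper names are hypothetical, and it assumes its inputs are already exactly representable in bf16 (it raises otherwise).

```python
import struct
from typing import Sequence


def bf16_bits(x: float) -> int:
    """16-bit pattern of x; bf16 is the upper half of the float32 encoding."""
    (u32,) = struct.unpack("<I", struct.pack("<f", x))
    if u32 & 0xFFFF:
        raise ValueError(f"{x!r} is not exactly representable in bf16")
    return u32 >> 16


def _ordered(bits: int) -> int:
    """Map sign-magnitude bits to an integer that is monotonic in the value."""
    magnitude = bits & 0x7FFF
    return -magnitude if bits & 0x8000 else magnitude


def bf16_ulp_distance(a: float, b: float) -> int:
    """Number of representable bf16 steps between a and b."""
    return abs(_ordered(bf16_bits(a)) - _ordered(bf16_bits(b)))


def within_one_ulp(xs: Sequence[float], ys: Sequence[float]) -> bool:
    """Elementwise forward-parity check with a 1-ULP tolerance."""
    return len(xs) == len(ys) and all(
        bf16_ulp_distance(x, y) <= 1 for x, y in zip(xs, ys)
    )
```

Since bf16 has 7 explicit mantissa bits, the neighbour of 1.0 is 1.0078125 (1 + 2^-7), so those two values are exactly 1 ULP apart under this measure.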
…les per-expert path
Adds dist-checkpoint support to A's N-submodule per-expert weight quantization
path on _QuantMegatronTEGroupedLinear. Mirrors B's gather-once-cache pattern
(hungyuehc/omniml-4998-umbrella) adapted for A's storage layout — N separate
weight_quantizer_i._amax scalars across N submodules instead of one [N,1,1]
buffer on a single quantizer.
Methods added:
- _ep_group: returns the EP group if initialized and world_size > 1.
- _gather_global_per_expert_amax_n_modules: stacks N local scalar amaxes from
the submodules, all-gathers across EP, returns [N_global]. None when the
layer is not in per-expert mode.
- sharded_state_dict: caches the gathered global vector before delegating to
super so the EP collective completes BEFORE Megatron's dist-checkpoint save
fires default-PG ALLGATHER metadata exchanges (interleaving EP + default-PG
collectives deadlocks NCCL — codified in
[[feedback-no-custom-collectives-in-dist-ckpt-save]]).
Methods replaced:
- _process_quantizer_amax: emits the cached global [N_global] vector under
every weight_quantizer_i._amax key in per-expert mode. Suboptimal disk
usage (N copies of same vector per layer) but mirrors B's pattern and
avoids surgery into the base-class state-dict iteration.
- _load_from_state_dict: preserves the existing _extra_state{i} filter and
adds the per-expert narrow — pulls element (ep_rank * N_local + i) out of
each saved [N_global] vector for the i-th local submodule. Falls through
unchanged when v.numel() != global_size (legacy / EP=1 save format).
Validated via the AC4 parity test (next commit) at TP=2, EP=2 across
FP8_DEFAULT_CFG and NVFP4_DEFAULT_CFG.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
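The save-side gather and load-side narrow described in this commit message can be simulated without a process group. In this sketch plain Python lists stand in for tensors and concatenation in rank order stands in for the EP all-gather; the function names are hypothetical and only the index rule `ep_rank * N_local + i` and the size-mismatch fall-through come from the text above.

```python
from typing import List, Sequence, Union


def gather_global_amax(local_amax_per_rank: Sequence[Sequence[float]]) -> List[float]:
    """Simulated all-gather: concatenate each rank's [N_local] list in rank order."""
    return [a for rank_vals in local_amax_per_rank for a in rank_vals]


def narrow_for_local_expert(saved: List[float], global_size: int, ep_rank: int,
                            n_local: int, local_i: int) -> Union[float, List[float]]:
    """Pick this rank's i-th expert out of a saved [N_global] vector."""
    if len(saved) != global_size:
        return saved  # legacy / EP=1 save format: pass through unchanged
    return saved[ep_rank * n_local + local_i]


# Two EP ranks, two local experts each.
per_rank = [[0.1, 0.2], [0.3, 0.4]]
global_vec = gather_global_amax(per_rank)

# Every rank recovers exactly the values it contributed.
restored = [
    [narrow_for_local_expert(global_vec, 4, rank, 2, i) for i in range(2)]
    for rank in range(2)
]
```

Because every `weight_quantizer_i._amax` key stores the same global vector, the narrow depends only on the rank and the local index, not on which key the vector was read from.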
…dict round-trip
Adds test_te_grouped_n_modules_sharded_state_dict parameterized over
FP8_DEFAULT_CFG and NVFP4_DEFAULT_CFG. Builds a TEGroupedMLP model at TP=2
EP=2 num_moe_experts=4, quantizes with the
MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER=1 env-var path, saves dist-ckpt,
restores into a fresh meta-device model, asserts equivalence via the existing
sharded_state_dict_test_helper.
Mirrors the layout of B's test_te_grouped_per_expert_sharded_state_dict
(hidden_size=256, dist_workers fixture) but triggers A's env-var-gated
per-expert path instead of B's axis=0 quant_cfg knob.
Also cherry-picks the OMNIML-5030 sequence_parallel fix that A's branch
predates: get_mcore_gpt_model in tests/_test_utils/torch/megatron/models.py
gains a sequence_parallel parameter that threads through TransformerConfig,
and the non-hybrid call site in _gpt_model_provider passes
sequence_parallel=(tp_size > 1). Without this, Megatron-Core ValueError-s
during MoE + TP > 1 model construction. The hybrid path on A's branch already
had this fix.
Validated on aws-cmh slurm: 2 passed, 92.99s wall (job 537620).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…d N-loop overhead
The Triton dispatch in te_grouped_quantized_linear_fn calls
_gather_per_expert_amax on every forward. The gather walks N submodules
(O(N) Python overhead) and stacks N scalar _amax buffers into a [N] fp32
vector. The result is invariant across forwards (per-expert _amax does
not change outside calibration, and the gate already blocks the Triton
path when q._if_calib is True on any quantizer).
Caching the gathered tensor lazily on first call eliminates the per-forward
overhead. Invalidation hook _invalidate_per_expert_amax_cache is called from
modelopt_post_restore (where dist-ckpt reload may have changed _amax).
Measured impact on OMNIML-5064 microbench (Nemotron Nano EP=4, N=32):
fwd_us: 1918 (no-cache) -> 1244 (cached) (35% drop)
step_us: 3444 (no-cache) -> 2785 (cached) (19% drop)
vs Btriton5 (B's path with same Triton kernel):
ATriton-cached 1244 vs Btriton5 1208 -> effectively tied
ATriton-cached step 2785 vs Btriton5 step 2815 -> ATriton edges B
Without the cache the gap to Btriton5 grows with N (1.59x at N=32, 2.18x
at N=128, observed in the no-cache nano_ep4 + super_ep4 runs). With the
cache, the gap closes to within run-to-run noise.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
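The lazy cache with an explicit invalidation hook can be sketched as follows. This is an assumption-laden illustration rather than the PR's code: `PerExpertAmaxCache` is a hypothetical class, dictionaries stand in for the per-expert quantizer submodules, and a tuple stands in for the stacked `[N]` fp32 tensor.

```python
from typing import Dict, List, Optional, Tuple


class PerExpertAmaxCache:
    """Gather N per-expert amax values once, reuse until invalidated."""

    def __init__(self, quantizers: List[Dict[str, float]]) -> None:
        self._quantizers = quantizers
        self._cached: Optional[Tuple[float, ...]] = None
        self.gather_calls = 0  # counts how often the O(N) walk actually runs

    def get(self) -> Tuple[float, ...]:
        if self._cached is None:
            # O(N) walk over the submodules, done only on a cache miss.
            self.gather_calls += 1
            self._cached = tuple(float(q["amax"]) for q in self._quantizers)
        return self._cached

    def invalidate(self) -> None:
        """Call after anything that may change amax, e.g. a checkpoint restore."""
        self._cached = None


experts = [{"amax": 1.0}, {"amax": 2.0}]
cache = PerExpertAmaxCache(experts)
first = cache.get()
second = cache.get()          # served from the cache, no second walk

experts[1]["amax"] = 5.0      # simulate a restore that changed one amax
cache.invalidate()
after_restore = cache.get()
```

The cache is only safe because the gate already keeps the Triton path off while any quantizer is calibrating; a change to amax outside the invalidation hook would otherwise be served stale.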
The doc was a placeholder for B-triton's pre-validation tracking (commit 0bf4838 on PR #1671); validation has since landed and the file is no longer load-bearing. The kernel module references it from a docstring line that is stale and will be cleaned up by PR #1671's review pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…_quantizer_<N> siblings
The N-quantizers opt-in (MODELOPT_TEGROUPED_PER_EXPERT_QUANTIZER=1) adds
per-expert siblings as `weight_quantizer_<digits>` via add_module. The
canonical `*weight_quantizer` wildcard in stock QUANT_CFG entries does not
match these underscore-suffix names, so each sibling stayed at default
(`_disabled=True`). Forward called every sibling on every calibration batch,
but each call early-returned the inputs unchanged: calibration was a no-op
for routed-expert weights, `_amax` was never populated, the save layer had
nothing to write.
Extend `_normalize_fused_experts_quantizer_name` in conversion.py to collapse
the underscore-separator form `weight_quantizer_<N>` (and
`input_quantizer_<N>`) to the canonical singular name before fnmatch,
alongside the existing `weight_quantizer[s]?.N` dot-separator form emitted by
fused-experts plugins. Single-character widening: `\.` → `[._]` in the
separator class.
Stock QUANT_CFG patterns (`*weight_quantizer`, `*input_quantizer`) now reach
the N siblings, so they get enabled, configured, calibrated, and saved like
any other quantizer.
Confirmed end-to-end on full Nemotron-3-Nano PTQ smoke (NVFP4_DEFAULT_CFG,
INT8_DEFAULT_CFG): saved per-expert `_amax` keys went from 0 to 1472
(23 MoE layers x 2 linears x 32 local experts).
Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
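The effect of widening the separator class can be reproduced with `re` and `fnmatch` alone. The regular expression and function names below are assumptions written for illustration; they are not copied from `_normalize_fused_experts_quantizer_name`, and the module paths in the example are made up.

```python
import re
from fnmatch import fnmatch

# Hypothetical pattern: optional plural "s", then "." or "_", then the index.
_SIBLING_SUFFIX = re.compile(r"(weight_quantizer|input_quantizer)s?[._]\d+$")


def normalize_quantizer_name(name: str) -> str:
    """Collapse `weight_quantizer_<N>` / `weight_quantizers.<N>` to the canonical name."""
    return _SIBLING_SUFFIX.sub(r"\1", name)


def matches_cfg_pattern(name: str, pattern: str) -> bool:
    """Match a module name against a QUANT_CFG-style wildcard after normalizing."""
    return fnmatch(normalize_quantizer_name(name), pattern)


sibling = "decoder.layers.0.mlp.experts.linear_fc1.weight_quantizer_7"

raw_match = fnmatch(sibling, "*weight_quantizer")                    # before the fix
normalized_match = matches_cfg_pattern(sibling, "*weight_quantizer")  # after the fix
```

Normalizing the name rather than adding new wildcard entries keeps the stock config patterns untouched, which is the design choice the commit message describes.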
…er-expert path
The original per-expert save/load code in `_QuantMegatronTEGroupedLinear`
assumed each sibling's `_amax` was a scalar:
- `_gather_global_per_expert_amax_n_modules` did `amax.view(())` then
stacked into `[N_local]`, all-gather over EP -> `[N_global]`.
- `_process_quantizer_amax` flattened the cached tensor via
`cached.view(cached.numel())`.
- `_load_from_state_dict` narrowed via
`v.view(global_size)[offset+i].view(())` on a local filtered copy.
Holds for NVFP4_DEFAULT_CFG / FP8_DEFAULT_CFG (per-tensor weight `_amax`)
but breaks for any axis!=None config. INT8_DEFAULT_CFG (axis=0
per-channel) produces `_amax` shape `[out_per_expert, 1]` -- for
Nemotron-Nano fc1 that's `[1856, 1]`. The save side raised
`RuntimeError: shape '[]' is invalid for input of size 1856` at
`amax.view(())`.
Generalize all three methods to preserve native `_amax` shape:
- Gather: stack along a new leading axis without flattening, produce
`[N_local, *amax_shape]` locally and `[N_global, *amax_shape]` after
EP all-gather.
- Process: write the cached tensor verbatim under each
`weight_quantizer_<i>._amax` key; fallback path drops the scalar-only
assertion and emits the buffer in its native shape.
- Load: detect `v.shape[0] == global_size` and slice dim 0, keeping
trailing dims (`v[offset+local_i].contiguous()`). Fall back to the
legacy flat-vector narrow when `v.numel() == global_size` but
`v.dim() == 1`, for backward compat with previously-saved ckpts.
The load fix also has to mutate `state_dict` in place rather than passing
a filtered copy to `super()._load_from_state_dict`. PyTorch's
`Module.load_state_dict` recursion builds each child's filtered dict by
re-filtering the PARENT's `local_state_dict` AFTER the parent's
`_load_from_state_dict` returns; modifications to a local copy don't
propagate to the per-expert sibling children that actually own the
`_amax` buffers, and the strict-load size check rejects the un-narrowed
`[N_global, *amax_shape]` tensor as not matching the local
`[*amax_shape]` buffer. The `_extra_state{N}` suppression filter still
uses a local view because it only affects the parent's own strict-load
unexpected-key check.
Confirmed end-to-end on the full Nemotron-3-Nano PTQ + sharded ckpt
round-trip smoke:
- NVFP4_DEFAULT_CFG (scalar `_amax`): 1472 per-expert keys round-trip,
siblings carry distinct per-expert values.
- INT8_DEFAULT_CFG (per-channel `[1856, 1]` `_amax`): 1472 per-expert
keys round-trip with the `[1856, 1]` shape preserved on every
sibling; per-expert max values distinct across siblings.
Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
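The shape-preserving load rule can be sketched with nested lists standing in for tensors. The helper names are hypothetical, and the sketch simplifies one thing: with lists, the native-shape branch and the legacy flat-vector branch both reduce to indexing dim 0, whereas the commit message describes them as separate checks on tensors.

```python
from typing import Any, List, Tuple


def shape_of(v: Any) -> Tuple[int, ...]:
    """Shape of a rectangular nested list, e.g. [[1.0], [2.0]] -> (2, 1)."""
    dims: List[int] = []
    while isinstance(v, list):
        dims.append(len(v))
        v = v[0] if v else None
    return tuple(dims)


def narrow_saved_amax(v: Any, global_size: int, ep_rank: int,
                      n_local: int, local_i: int) -> Any:
    """Slice dim 0 of a saved [N_global, *amax_shape] value for one local expert."""
    shape = shape_of(v)
    idx = ep_rank * n_local + local_i
    if len(shape) >= 2 and shape[0] == global_size:
        return v[idx]  # native shape: trailing dims are kept as-is
    if len(shape) == 1 and shape[0] == global_size:
        return v[idx]  # legacy flat [N_global] vector: yields a scalar
    return v           # anything else is passed through untouched


# Per-channel amax: 4 global experts, 3 output channels each, shape [4, 3, 1].
saved = [[[float(10 * e + c)] for c in range(3)] for e in range(4)]
local = narrow_saved_amax(saved, global_size=4, ep_rank=1, n_local=2, local_i=1)
```

Here `local` is the `[3, 1]` slice for global expert 3, which is the shape the local sibling's buffer expects under strict loading.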
…ensorrt=False)
The trtllm `fp4_quantize` CUDA op returns `weights_scaling_factor` in
cutlass-interleaved layout with `out_features` padded up to a multiple of 128
for GEMM tile alignment. For a Nemotron-Nano mamba `in_proj` weight of shape
`[10304, 2688]` this produces a flat `_scale` buffer of `10368 * 168` =
1,741,824 elements, versus 1,731,072 elements for the canonical
`[10304, 168]` shape. The shape mismatch breaks Megatron sharded-ckpt save/load:
CheckpointingException: Global shape mismatch for loaded
(torch.Size([1741824])) and expected ((10304, 168))
for key decoder.layers.0.mixer.in_proj.weight_quantizer._scale
at the rq-export step (step 2's load of the compressed ckpt produced by
`mtq.compress` in step 1).
Pass `try_tensorrt=False` so `NVFP4QTensor.quantize` takes the canonical
Python path:
per_block_amax = reduce_block_amax(input, block_sizes={-1: block_size})
per_block_scale = per_block_amax / (6.0 * weights_scaling_factor_2)
→ shape == (out, in/block_size) ✓ matches load-side expectation
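The canonical computation above can be written out in plain Python to show where the `(out, in/block_size)` shape comes from. This is a sketch under assumptions: nested lists stand in for tensors, `block_amax_last_dim` and the other helper names are hypothetical, and `weights_scaling_factor_2` is taken as a given input rather than derived.

```python
from typing import List, Tuple


def block_amax_last_dim(rows: List[List[float]], block_size: int = 16) -> List[List[float]]:
    """Max absolute value over consecutive blocks of the last dimension."""
    out = []
    for row in rows:
        if len(row) % block_size != 0:
            raise ValueError("row length must be a multiple of block_size")
        out.append([
            max(abs(x) for x in row[i:i + block_size])
            for i in range(0, len(row), block_size)
        ])
    return out


def per_block_scale(rows: List[List[float]], weights_scaling_factor_2: float,
                    block_size: int = 16) -> List[List[float]]:
    """per_block_amax / (6.0 * weights_scaling_factor_2); 6.0 is the FP4 E2M1 maximum."""
    amax = block_amax_last_dim(rows, block_size)
    return [[a / (6.0 * weights_scaling_factor_2) for a in r] for r in amax]


def canonical_scale_shape(out_features: int, in_features: int,
                          block_size: int = 16) -> Tuple[int, int]:
    """Unpadded, non-interleaved scale shape."""
    return (out_features, in_features // block_size)


weight = [[1.0] * 16 + [-3.0] * 16, [0.5] * 32]   # shape (2, 32)
scales = per_block_scale(weight, weights_scaling_factor_2=0.5)
in_proj_shape = canonical_scale_shape(10304, 2688)
```

For the `in_proj` weight quoted above, `in_proj_shape` is `(10304, 168)`, which is the shape the load side expects.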
Trade-off: the cutlass NVFP4 GEMM kernel wants the interleaved layout for
inference. Re-interleaving happens at HF-export time via
`cutlass_fp4_scale_to_modelopt_fp4_scale` and its inverse (already wired
into the export-to-HF flow at unified_export_hf.py and the deserializer
in nvfp4_tensor.py line 279), so persisting the canonical form does not
block deployment — it just moves the layout transform from
save-time-only to export-time.
Smoke evidence on full Nemotron-3-Nano-30B (NVFP4_DEFAULT_CFG, EP=4):
- Pre-fix: rq pipeline crashed at step 2 load with the global-shape
mismatch above on `decoder.layers.0.mixer.in_proj.weight_quantizer._scale`.
- Post-fix: step 2 load succeeds, model proceeds into HF export.
A downstream packed-FP4 reshape bug in `unified_export_megatron.py::
_qkv_slicing` (line 1068) is the next blocker but is unrelated to
save/load layout — tracked separately.
Signed-off-by: Hung-Yueh Chiang <hungyuehc@nvidia.com>
What does this PR do?
Type of change: new feature
Implements OMNIML-5072, built on top of PR #1550 (the N-quantizer per-expert foundation, still WIP). Three additions:
- Wires the `grouped_axis0_fakequant` kernel from PR #1671 (the One-Vec-quantizer path) into `_QuantTEGroupedLinear.te_grouped_quantized_linear_fn` when `_per_expert_weight_quantizer == True`. Soft-gated on `_triton_kernels.IS_AVAILABLE` and on `q._if_calib` being off; falls back to the existing `cuda_ext` per-quantizer loop when the gate is False.
- `_gather_per_expert_amax` helper eliminates per-forward O(N) Python overhead (walks the N submodules from PR #1550, "WIP Support per expert amax in TEGroupedMLP", and stacks the N scalar `_amax` buffers; lazily cached, invalidated from `modelopt_post_restore`).
- `sharded_state_dict` save + EP-aware load on `_QuantMegatronTEGroupedLinear`'s N-quantizer case: gather-once-cache pattern adapted to N scalar `_amax` buffers across the EP group, so the dist-ckpt round-trip that PR #1671 ("[OMNIML-4998] Per-expert weight quantization on TEGroupedLinear") ships for One-Vec-quantizer also works on the N-quantizer foundation from PR #1550.
This PR is stacked on top of PR #1550 (`jennifchen/te_per_expert`, still WIP). The diff against `main` includes PR #1550's commits underneath; the OMNIML-5072-specific work is the top 6 commits (`fd77b53d8..51b4c9226`). Once PR #1550 lands, rebase onto `main` to shrink the review surface.
Usage
Testing
GPU-validated on aws-cmh (B300, nemo:26.02 / nemo:25.11 containers):
- `pass_through_bwd=True`. Test at `tests/gpu/torch/quantization/plugins/test_te_grouped_triton_parity.py`.
- `FP8_DEFAULT_CFG` and `NVFP4_DEFAULT_CFG`. Test at `tests/gpu_megatron/torch/quantization/plugins/test_megatron.py::test_te_grouped_n_modules_sharded_state_dict`.
Before your PR is "Ready for review"
Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- `_triton_kernels.IS_AVAILABLE` is False or during calibration (`q._if_calib`), the original `cuda_ext` per-quantizer loop runs unchanged.
- `CONTRIBUTING.md`: ✅ Reused the `grouped_axis0_fakequant` kernel module from PR #1671 (`51b4c9226` reuses it via the new `_GroupedAxis0FakeQuantFn` autograd adapter). No new PIP dependencies.
- `sharded_state_dict` TP=2/EP=2 test.
- `/claude review` after rebase onto `main`.
Additional Information
Related work:
🤖 Generated with Claude Code