Add: support input_shape_profile for trt-rtx ep #1782

Open
haoxiz-nvidia wants to merge 1 commit into main from haoxiz/onnx-ptq-model-id

@haoxiz-nvidia haoxiz-nvidia commented Jun 22, 2026

What does this PR do?

Add support for passing model_id as an input to ONNX quantization, which fixes the missing input_shapes_profile problem on some versions of TRT-RTX.

Usage

python -m modelopt.onnx.quantization --onnx_path="path\to\model.onnx" --quantize_mode=int8 --output_path="path\to\output\model.onnx" --calibration_eps=NvTensorRtRtx --use_external_data_format --high_precision_dtype=fp32 --model_id="huggingface_model_id"
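
For scripted runs, the same invocation can be assembled as an argument list. This is a minimal sketch that only mirrors the flags shown in the command above; the helper name `build_quantize_command` and the example paths are hypothetical and not part of the PR.

```python
import shlex
import sys


def build_quantize_command(onnx_path, output_path, model_id, quantize_mode="int8"):
    """Return the argv list for the CLI call shown in the Usage section.

    Hypothetical helper: it only reuses flags that appear in the PR description.
    """
    return [
        sys.executable,
        "-m",
        "modelopt.onnx.quantization",
        f"--onnx_path={onnx_path}",
        f"--quantize_mode={quantize_mode}",
        f"--output_path={output_path}",
        "--calibration_eps=NvTensorRtRtx",
        "--use_external_data_format",
        "--high_precision_dtype=fp32",
        f"--model_id={model_id}",
    ]


if __name__ == "__main__":
    # Placeholder paths and model id, for illustration only.
    cmd = build_quantize_command("model.onnx", "out/model.onnx", "huggingface_model_id")
    print(shlex.join(cmd))
```

Building a list instead of a single string keeps paths with spaces or backslashes intact when the list is handed to a process launcher.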

Testing

Tested on 4 popular LLMs with all popular quantization methods (int4, fp8, int8).

Before your PR is "Ready for review"

  • Is this change backward compatible?: ✅
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: ✅
  • Did you write any new necessary tests?: ❌
  • Did you update Changelog?: ❌
  • Did you get Claude approval on this PR?: N/A

Summary by CodeRabbit

  • New Features
    • Added model_id parameter to the ONNX quantization CLI and core quantization functions, allowing automatic generation of input shape profiles when not explicitly provided.

Signed-off-by: haoxiz <haoxiz@nvidia.com>
@haoxiz-nvidia haoxiz-nvidia self-assigned this Jun 22, 2026
@haoxiz-nvidia haoxiz-nvidia requested a review from a team as a code owner June 22, 2026 04:48
@coderabbitai

coderabbitai Bot commented Jun 22, 2026

Walkthrough

Adds model_id as an optional parameter to the ONNX PTQ quantize() API and CLI. When input_shapes_profile is not provided and model_id is set, a new create_input_shapes_profile helper loads HF AutoConfig to derive per-EP min/opt/max attention shape dictionaries. The profile is then threaded through MHA/MatMul exclusion graph analysis and ORT calibration EP configuration.
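
To make the walkthrough concrete, here is a rough sketch of how a per-EP list of min/opt/max shape dictionaries could be derived from attention dimensions. It is not the PR's `create_input_shapes_profile`. The `AttentionConfig` stand-in (used instead of loading HF AutoConfig), the input names, the `name:d0xd1,...` string encoding, the batch and sequence sizes, and the choice of which EPs receive a profile are all assumptions made for illustration. Only the `trt_profile_min_shapes` key style and the `trt` / `cuda:0` EP names appear elsewhere in this thread.

```python
from dataclasses import dataclass


@dataclass
class AttentionConfig:
    """Stand-in for the attention fields a Hugging Face config typically exposes."""

    num_hidden_layers: int
    num_key_value_heads: int
    head_dim: int


def _shape_string(cfg, batch, seq_len, past_len):
    """Encode shapes as 'name:d0xd1...,name2:...' (assumed format and input names)."""
    parts = [
        f"input_ids:{batch}x{seq_len}",
        f"attention_mask:{batch}x{past_len + seq_len}",
    ]
    kv_shape = f"{batch}x{cfg.num_key_value_heads}x{past_len}x{cfg.head_dim}"
    for layer in range(cfg.num_hidden_layers):
        parts.append(f"past_key_values.{layer}.key:{kv_shape}")
        parts.append(f"past_key_values.{layer}.value:{kv_shape}")
    return ",".join(parts)


def sketch_input_shapes_profile(cfg, calibration_eps, profiled_eps=("trt",)):
    """Return one dict per requested EP; EPs without a profile get an empty dict."""
    profile = []
    for ep in calibration_eps:
        if ep in profiled_eps:
            profile.append(
                {
                    # The sizes below are arbitrary placeholders, not values from the PR.
                    "trt_profile_min_shapes": _shape_string(cfg, 1, 1, 1),
                    "trt_profile_opt_shapes": _shape_string(cfg, 1, 1, 512),
                    "trt_profile_max_shapes": _shape_string(cfg, 1, 1024, 1024),
                }
            )
        else:
            profile.append({})
    return profile


if __name__ == "__main__":
    cfg = AttentionConfig(num_hidden_layers=2, num_key_value_heads=4, head_dim=64)
    for entry in sketch_input_shapes_profile(cfg, ["trt", "cuda:0"]):
        print(entry)
```

Keeping one entry per requested EP, with empty dicts for EPs that need no profile, is what allows the list to be indexed in parallel with calibration_eps later in the pipeline.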

Changes

Input Shape Profile Pipeline

| Layer / File(s) | Summary |
| --- | --- |
| Shape profile generation and ORT EP wiring (modelopt/onnx/quantization/ort_utils.py) | Adds create_input_shapes_profile(model_id, calibration_eps) that loads HF AutoConfig, derives attention dimensions, and returns a per-EP list of min/opt/max shape dicts. Extends configure_ort with input_shapes_profile; when provided, builds execution_providers with per-EP (ep, provider_options) tuples for non-empty profile entries and passes them into TRT-guided options. |
| graph_utils inference session and exclusion propagation (modelopt/onnx/quantization/graph_utils.py) | Adds input_shapes_profile: Sequence[dict[str, str]] \| None = None to get_extended_model_outputs, find_nodes_from_matmul_to_exclude, _exclude_matmuls_by_inference, and find_nodes_from_mha_to_exclude, forwarding it into create_inference_session for both external-data and in-memory session paths. |
| INT8 and FP8 quantize() extension (modelopt/onnx/quantization/int8.py, modelopt/onnx/quantization/fp8.py) | Adds input_shapes_profile parameter to int8.quantize() and fp8.quantize(), forwarding it into find_nodes_from_matmul_to_exclude (GEMV/TRT exclusion) and configure_ort. |
| Top-level quantize() and CLI wiring (modelopt/onnx/quantization/quantize.py, modelopt/onnx/quantization/__main__.py) | Adds model_id: str \| None = None to top-level quantize() and its docstring; auto-calls create_input_shapes_profile when model_id is set and input_shapes_profile is absent; threads profile into find_nodes_from_mha_to_exclude and the INT8/FP8 dispatch. Adds --model_id CLI argument. |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes


Caution

Pre-merge checks failed

Please resolve all errors before merging. Addressing warnings is optional.

❌ Failed checks (1 error)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Security Anti-Patterns | ❌ Error | PR adds three # nosec comments (B404, B603, B108) in ort_utils.py and benchmark.py to bypass Bandit security checks. SECURITY.md explicitly prohibits this pattern. | Remove all # nosec comments. Add inline comments explaining WHY subprocess/temp file usage is safe, then request security review as required by SECURITY.md guidelines. |
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit's high-level summary is enabled. |
| Title check | ✅ Passed | The title 'Add: support input_shape_profile for trt-rtx ep' directly relates to the main objective of adding input_shape_profile support for the TensorRT-RTX execution provider, which is the core feature implemented across multiple files in the changeset. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 92.86% which is sufficient. The required threshold is 80.00%. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

@github-actions

PR Preview Action v1.8.1

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1782/

Built to branch gh-pages at 2026-06-22 04:51 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

@coderabbitai coderabbitai Bot left a comment

Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@modelopt/onnx/quantization/ort_utils.py`:
- Around line 595-603: The issue is that after _prepare_ep_list filters the
calibration_eps list to remove unavailable providers, the enumeration of
execution_providers uses indices from the filtered list instead of the original
list, causing the input_shapes_profile indices to misalign. To fix this,
enumerate over the original calibration_eps list instead of the filtered
execution_providers list when building the tuple pairs, using the index to
access input_shapes_profile correctly, and mapping each original ep to either
the profile (if available) or the filtered execution_providers equivalent.

In `@modelopt/onnx/quantization/quantize.py`:
- Around line 557-559: The input_shapes_profile is being created from
calibration_eps before it has been finalized by the update_trt_ep_support
function, causing potential sync issues downstream. Move the conditional block
that checks if input_shapes_profile is None and calls
create_input_shapes_profile with model_id and calibration_eps to execute after
update_trt_ep_support has been called, ensuring calibration_eps reflects the
final list of execution providers before generating the profile.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 19ed1a5a-2793-4772-b650-d3982467b520

📥 Commits

Reviewing files that changed from the base of the PR and between 9048d13 and db840b4.

📒 Files selected for processing (6)
  • modelopt/onnx/quantization/__main__.py
  • modelopt/onnx/quantization/fp8.py
  • modelopt/onnx/quantization/graph_utils.py
  • modelopt/onnx/quantization/int8.py
  • modelopt/onnx/quantization/ort_utils.py
  • modelopt/onnx/quantization/quantize.py

Comment on lines +595 to +603
    execution_providers = _prepare_ep_list(calibration_eps)
    if input_shapes_profile is not None:
        assert len(calibration_eps) == len(input_shapes_profile), (
            "Number of calibration EPs and number of input-shapes-profile don't match"
        )
        execution_providers = [
            (ep, input_shapes_profile[idx]) if input_shapes_profile[idx] else ep
            for idx, ep in enumerate(execution_providers)
        ]

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
python - <<'PY'
# Demonstrates why enumerate(filtered_execution_providers) can misalign profiles.
calibration_eps = ["trt", "cuda:0"]  # requested order
input_shapes_profile = [{"trt_profile_min_shapes": "x"}, {}]

# Simulate `_prepare_ep_list(calibration_eps)` when TRT is unavailable:
execution_providers = [("CUDAExecutionProvider", {"device_id": 0})]

mapped = [
    (ep, input_shapes_profile[idx]) if input_shapes_profile[idx] else ep
    for idx, ep in enumerate(execution_providers)
]
print("Mapped providers:", mapped)
print("Expected: CUDA should not receive TRT profile.")
PY

Repository: NVIDIA/Model-Optimizer

Length of output: 212


🏁 Script executed:

# First, let's examine the actual code at the specified file and lines
head -610 modelopt/onnx/quantization/ort_utils.py | tail -30

Repository: NVIDIA/Model-Optimizer

Length of output: 1389


🏁 Script executed:

# Check the _prepare_ep_list function to understand what it does
rg -A 15 "def _prepare_ep_list" modelopt/onnx/quantization/ort_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 937


🏁 Script executed:

# Look for context around the function that contains lines 595-603
rg -B 20 "execution_providers = _prepare_ep_list" modelopt/onnx/quantization/ort_utils.py | head -50

Repository: NVIDIA/Model-Optimizer

Length of output: 798


🏁 Script executed:

# Search for where input_shapes_profile is used to understand the expected alignment
rg -n "input_shapes_profile" modelopt/onnx/quantization/ort_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 895


🏁 Script executed:

# Get the full function containing lines 595-603
rg -B 50 "execution_providers = _prepare_ep_list" modelopt/onnx/quantization/ort_utils.py | grep -A 50 "^[0-9]*-def "

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Check line numbers more precisely
sed -n '555,620p' modelopt/onnx/quantization/ort_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 2616


🏁 Script executed:

# Search for tests related to input_shapes_profile and calibrate_ort
fd -e py test | xargs rg -l "input_shapes_profile|calibrate_ort" 2>/dev/null

Repository: NVIDIA/Model-Optimizer

Length of output: 50


🏁 Script executed:

# Check _prepare_ep_list complete implementation
sed -n '/^def _prepare_ep_list/,/^def /p' modelopt/onnx/quantization/ort_utils.py | head -60

Repository: NVIDIA/Model-Optimizer

Length of output: 2115


🏁 Script executed:

# Search for tests in tests directory
find tests -name "*.py" -type f | xargs grep -l "input_shapes_profile\|calibrate_ort" 2>/dev/null || echo "No tests found matching pattern"

Repository: NVIDIA/Model-Optimizer

Length of output: 98


🏁 Script executed:

# Check if there are any tests for ort_utils specifically
find tests -name "*ort*" -type f

Repository: NVIDIA/Model-Optimizer

Length of output: 1267


🏁 Script executed:

# Look for any other calls to calibrate_ort to understand the impact
rg "calibrate_ort" --type py -A 3 -B 1

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Check test_ort_patching.py for relevant tests
cat tests/gpu/onnx/test_ort_patching.py

Repository: NVIDIA/Model-Optimizer

Length of output: 28863


🏁 Script executed:

# Check test_onnx_export_cpu.py for input_shapes_profile usage
rg -B 5 -A 10 "input_shapes_profile" tests/unit/torch/quantization/test_onnx_export_cpu.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Look at the function signature and documentation for calibrate_ort
rg -B 10 "def calibrate_ort" modelopt/onnx/quantization/ort_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Check if there are any docstrings or comments explaining the expected behavior
sed -n '535,560p' modelopt/onnx/quantization/ort_utils.py

Repository: NVIDIA/Model-Optimizer

Length of output: 1254


Profile indexing must use original EP list, not filtered providers.

Lines 597–601 use enumerate(execution_providers) after _prepare_ep_list(...) filters the input. When an early EP is dropped (e.g., TensorRT unavailable), the remaining providers are shifted, causing profile indices to misalign. For example, if calibration_eps=["trt", "cuda"] but TensorRT is unavailable, CUDA receives the TensorRT profile instead of an empty one, which can break ORT setup at runtime.

Suggested fix
-    execution_providers = _prepare_ep_list(calibration_eps)
-    if input_shapes_profile is not None:
-        assert len(calibration_eps) == len(input_shapes_profile), (
-            "Number of calibration EPs and number of input-shapes-profile don't match"
-        )
-        execution_providers = [
-            (ep, input_shapes_profile[idx]) if input_shapes_profile[idx] else ep
-            for idx, ep in enumerate(execution_providers)
-        ]
+    execution_providers = []
+    if input_shapes_profile is not None:
+        assert len(calibration_eps) == len(input_shapes_profile), (
+            "Number of calibration EPs and number of input-shapes-profile don't match"
+        )
+
+    for idx, requested_ep in enumerate(calibration_eps):
+        prepared = _prepare_ep_list([requested_ep])
+        if not prepared:
+            continue
+        ep = prepared[0]
+
+        profile = {} if input_shapes_profile is None else input_shapes_profile[idx]
+        if not profile:
+            execution_providers.append(ep)
+            continue
+
+        if isinstance(ep, tuple):
+            ep_name, ep_options = ep
+            execution_providers.append((ep_name, {**ep_options, **profile}))
+        else:
+            execution_providers.append((ep, profile))
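
The suggested loop can be exercised in isolation. The sketch below follows the same per-EP logic but swaps `_prepare_ep_list` for a hypothetical `prepare_one` stub driven by an `available` mapping, so it runs without ONNX Runtime. The stub, the mapping, and the function name `attach_profiles` are illustrative assumptions, and the real helper may behave differently.

```python
def prepare_one(requested_ep, available):
    """Hypothetical stand-in for _prepare_ep_list([ep]): empty list when unavailable."""
    return [available[requested_ep]] if requested_ep in available else []


def attach_profiles(calibration_eps, input_shapes_profile, available):
    """Pair each surviving EP with the profile at its ORIGINAL request index."""
    if input_shapes_profile is not None and len(calibration_eps) != len(input_shapes_profile):
        raise ValueError(
            "Number of calibration EPs and number of input-shapes-profile don't match"
        )

    providers = []
    for idx, requested_ep in enumerate(calibration_eps):
        prepared = prepare_one(requested_ep, available)
        if not prepared:
            continue  # EP dropped; its profile entry is skipped along with it
        ep = prepared[0]

        profile = {} if input_shapes_profile is None else input_shapes_profile[idx]
        if not profile:
            providers.append(ep)
        elif isinstance(ep, tuple):
            # Merge profile keys into provider options that already exist.
            ep_name, ep_options = ep
            providers.append((ep_name, {**ep_options, **profile}))
        else:
            providers.append((ep, profile))
    return providers


if __name__ == "__main__":
    # TRT is requested but unavailable, matching the scenario in the analysis script.
    available = {"cuda:0": ("CUDAExecutionProvider", {"device_id": 0})}
    profiles = [{"trt_profile_min_shapes": "x"}, {}]
    print(attach_profiles(["trt", "cuda:0"], profiles, available))
```

Because the index comes from the requested list rather than the filtered one, a dropped EP takes its profile entry with it and the remaining providers keep their own entries.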

Comment on lines +557 to +559
    if input_shapes_profile is None and model_id:
        input_shapes_profile = create_input_shapes_profile(model_id, calibration_eps)

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Build inferred profiles after calibration_eps is finalized.

At Line 557, profiles are inferred before update_trt_ep_support(...) updates calibration_eps later in this function. That can leave input_shapes_profile out-of-sync (length/order) with final EPs and trigger downstream failures.

💡 Suggested fix
-    if input_shapes_profile is None and model_id:
-        input_shapes_profile = create_input_shapes_profile(model_id, calibration_eps)
@@
     trt_plugins = update_trt_ep_support(calibration_eps, has_dds_op, has_custom_op, trt_plugins)  # type: ignore[arg-type]
+
+    if input_shapes_profile is None and model_id:
+        input_shapes_profile = create_input_shapes_profile(model_id, calibration_eps)
+    elif input_shapes_profile is not None and len(input_shapes_profile) != len(calibration_eps):
+        raise ValueError(
+            "Number of calibration EPs and number of input-shapes-profile don't match"
+        )
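
The ordering concern can be illustrated with stubs. In this sketch, `finalize_eps` and `make_profile` are hypothetical stand-ins for `update_trt_ep_support` and `create_input_shapes_profile`. How the real function changes calibration_eps is an assumption here: the stub simply prepends `trt` when asked.

```python
def finalize_eps(calibration_eps, needs_trt):
    """Hypothetical stand-in for update_trt_ep_support: may mutate the EP list in place."""
    if needs_trt and "trt" not in calibration_eps:
        calibration_eps.insert(0, "trt")


def make_profile(calibration_eps):
    """Hypothetical stand-in for create_input_shapes_profile: one dict per EP."""
    return [{"trt_profile_min_shapes": "x"} if ep == "trt" else {} for ep in calibration_eps]


def resolve_profile(calibration_eps, input_shapes_profile=None, model_id=None, needs_trt=False):
    """Finalize the EP list first, then infer or validate the profile against it."""
    finalize_eps(calibration_eps, needs_trt)

    if input_shapes_profile is None and model_id:
        input_shapes_profile = make_profile(calibration_eps)
    elif input_shapes_profile is not None and len(input_shapes_profile) != len(calibration_eps):
        raise ValueError(
            "Number of calibration EPs and number of input-shapes-profile don't match"
        )
    return input_shapes_profile


if __name__ == "__main__":
    eps = ["cuda:0"]
    profile = resolve_profile(eps, model_id="huggingface_model_id", needs_trt=True)
    print(eps, profile)
```

Had the profile been built before `finalize_eps` ran, it would contain one entry for a two-EP list, which is the length mismatch the review comment warns about.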

@codecov

codecov Bot commented Jun 22, 2026

Codecov Report

❌ Patch coverage is 28.57143% with 25 lines in your changes missing coverage. Please review.
✅ Project coverage is 75.69%. Comparing base (cfc823d) to head (db840b4).
⚠️ Report is 45 commits behind head on main.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| modelopt/onnx/quantization/ort_utils.py | 11.53% | 23 Missing ⚠️ |
| modelopt/onnx/quantization/graph_utils.py | 66.66% | 1 Missing ⚠️ |
| modelopt/onnx/quantization/quantize.py | 66.66% | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1782      +/-   ##
==========================================
- Coverage   77.09%   75.69%   -1.41%     
==========================================
  Files         511      511              
  Lines       56168    58272    +2104     
==========================================
+ Hits        43302    44107     +805     
- Misses      12866    14165    +1299     
| Flag | Coverage Δ |
| --- | --- |
| examples | 41.80% <14.28%> (-0.15%) ⬇️ |
| gpu | 57.67% <25.71%> (-0.64%) ⬇️ |
| unit | 54.41% <28.57%> (+0.07%) ⬆️ |
