-
Notifications
You must be signed in to change notification settings - Fork 404
user provided bound for torchtrt compile when export dimension is unb… #4213
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
afc27a7
143ba02
222f35c
93a1460
4047762
dc669ac
7544d23
69a0857
3c31af5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,11 +5,14 @@ | |
| import os | ||
| import platform | ||
| import warnings | ||
| from typing import Any, Collection, List, Optional, Sequence, Union | ||
| from typing import Any, Collection, Dict, List, Optional, Sequence, Tuple, Union | ||
|
|
||
| import sympy | ||
| import torch | ||
| import torch.utils._pytree as pytree | ||
| from torch.export import ExportedProgram | ||
| from torch.fx.node import Target | ||
| from torch.utils._sympy.numbers import int_oo | ||
| from torch_tensorrt._Device import Device | ||
| from torch_tensorrt._enums import EngineCapability, dtype | ||
| from torch_tensorrt._features import ENABLED_FEATURES, needs_cross_compile | ||
|
|
@@ -898,6 +901,164 @@ def _insert_complex_io_adapters( | |
| partitioned_module.recompile() | ||
|
|
||
|
|
||
def _validate_user_range_against_export(
    expr: sympy.Symbol,
    dim: torch.SymInt,
    user_min: int,
    user_max: int,
) -> None:
    """Check a user-supplied ``[user_min, user_max]`` against the exporter's range.

    The exported program may already bound ``expr`` to a finite range
    (e.g. ``Dim("batch", min=10, max=20)``). The compiled TRT engine's
    optimization profile follows that range; any shape outside it is rejected
    by TensorRT at runtime (``IExecutionContext::setInputShape``
    "satisfyProfile" check). Validating here -- at compile time -- surfaces
    the problem before the user hits that opaque runtime error.

    Args:
        expr: The symbol backing the dynamic dimension.
        dim: The ``SymInt`` the symbol was read from; its ``node.shape_env``
            (when present) supplies ``var_to_range``.
        user_min: Lower bound taken from ``Input.min_shape``.
        user_max: Upper bound taken from ``Input.max_shape``.

    Raises:
        ValueError: If ``user_max`` exceeds the exporter's finite max, or
            ``user_min`` is below the exporter's min (except the
            ``1 -> 2`` specialization case, which only warns).
    """
    shape_env = getattr(dim.node, "shape_env", None)
    if shape_env is None:
        return
    exp_range = shape_env.var_to_range.get(expr)
    if exp_range is None:
        return

    exp_upper = exp_range.upper
    if exp_upper is int_oo or exp_upper == sympy.oo:
        # Dim.DYNAMIC: user fills the gap (intended use).
        return
    try:
        exp_min = int(exp_range.lower)
        exp_max = int(exp_upper)
    except (TypeError, ValueError):
        # Bounds that cannot be expressed as plain ints are not validated.
        return
    if user_min == exp_min and user_max == exp_max:
        return

    mismatch = (
        f"Dynamic dimension '{expr}': "
        f"Input range [{user_min}, {user_max}] vs "
        f"exported program range [{exp_min}, {exp_max}]."
    )

    if user_max > exp_max:
        raise ValueError(
            f"{mismatch} Input.max_shape ({user_max}) exceeds the "
            f"exported program's max ({exp_max}). The program was "
            f"exported with this dimension bounded to "
            f"[{exp_min}, {exp_max}], so the compiled TensorRT engine "
            f"cannot accept shapes above {exp_max}. Either re-export "
            f"with Dim('{expr}', max={user_max}) or set "
            f"Input.max_shape <= {exp_max}."
        )

    if user_min < exp_min:
        # 1->2 is the 0/1 specialization artifact, not a user error.
        if user_min == 1 and exp_min == 2:
            logger.warning(
                "%s Input.min_shape=1 but the exported program's min "
                "is 2 (PyTorch 0/1 specialization -- Dim(min=1) is "
                "recorded as min=2). The compiled engine's min will "
                "be 2.",
                mismatch,
            )
            return
        raise ValueError(
            f"{mismatch} Input.min_shape ({user_min}) is below the "
            f"exported program's min ({exp_min}). The program was "
            f"exported with this dimension bounded to "
            f"[{exp_min}, {exp_max}], so the compiled TensorRT engine "
            f"cannot accept shapes below {exp_min}. Either re-export "
            f"with Dim('{expr}', min={user_min}) or set "
            f"Input.min_shape >= {exp_min}."
        )

    # Strict subset: engine profile narrows to the user's bounds
    # (applied in ``extract_var_range_info``). Not a warning -- the
    # user got exactly what they asked for.
    logger.info(
        "%s Narrowing engine profile to user bounds [%d, %d] "
        "(exported program range was [%d, %d]).",
        mismatch,
        user_min,
        user_max,
        exp_min,
        exp_max,
    )


def _build_user_symbol_bounds(
    gm: torch.fx.GraphModule,
    sample_arg_inputs: Sequence[Input],
    sample_kwarg_inputs: dict[Any, Any],
) -> Dict[sympy.Symbol, Tuple[int, int]]:
    """Map ``sympy.Symbol -> (min, max)`` from dynamic ``Input``s, used to
    fill ``Dim.DYNAMIC`` upper bounds without mutating ``ShapeEnv``.

    Validates against finite exporter bounds: ``user_max > exp_max`` and
    ``user_min < exp_min`` raise (TRT would reject those shapes at runtime);
    a strict subset narrows the engine profile to the user's bounds (info
    log only); the ``user_min=1, exp_min=2`` case warns -- it's PyTorch's
    0/1 specialization artifact, not a user error.

    Args:
        gm: Graph module whose placeholder nodes carry fake tensors in
            ``node.meta["val"]``.
        sample_arg_inputs: Positional sample inputs.
        sample_kwarg_inputs: Keyword sample inputs.

    Returns:
        One ``(min, max)`` entry per plain symbol found on a dynamic input.
        When a symbol appears on several inputs, the first occurrence wins.

    Raises:
        ValueError: If an ``Input``'s rank differs from the placeholder's,
            if a static dimension is given a non-fixed range, or if the
            user range falls outside a finite exporter range.
    """
    placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]

    # Flatten args+kwargs in pytree order and pair them positionally with the
    # placeholders, so no name matching is needed.
    # NOTE(review): this assumes the flattened order matches placeholder
    # order. A reviewer suggested deriving the flattening from the exported
    # program's in_spec instead -- confirm before relying on it.
    flat_inputs, _ = pytree.tree_flatten(
        (list(sample_arg_inputs), sample_kwarg_inputs)
    )

    if len(placeholders) != len(flat_inputs):
        # zip() below truncates silently; leave a trace for debugging.
        logger.debug(
            "Graph has %d placeholders but %d flattened sample inputs; only "
            "the first %d pairs are inspected for user-supplied bounds.",
            len(placeholders),
            len(flat_inputs),
            min(len(placeholders), len(flat_inputs)),
        )

    user_symbol_bounds: Dict[sympy.Symbol, Tuple[int, int]] = {}

    for node, inp in zip(placeholders, flat_inputs):
        if not (isinstance(inp, Input) and inp.shape_mode == Input._ShapeMode.DYNAMIC):
            continue
        fake_val = node.meta.get("val")
        if not isinstance(fake_val, torch.Tensor):
            continue

        min_shape = inp.shape["min_shape"]
        max_shape = inp.shape["max_shape"]
        fake_shape = fake_val.size()
        rank = len(fake_shape)

        # Both shapes are indexed per-dimension below, so check both ranks up
        # front rather than failing later with a bare IndexError.
        if len(min_shape) != rank or len(max_shape) != rank:
            raise ValueError(
                f"Input '{node.target}' has {rank} dimensions in "
                f"the exported program, but the provided Input specifies "
                f"{len(min_shape)} (min_shape) and {len(max_shape)} "
                f"(max_shape) dimensions. Ensure Input.min_shape, "
                f"Input.opt_shape, and Input.max_shape each have "
                f"{rank} entries."
            )

        for d, dim in enumerate(fake_shape):
            if not isinstance(dim, torch.SymInt):
                if min_shape[d] != dim or max_shape[d] != dim:
                    raise ValueError(
                        f"Input '{node.target}' dim {d} is static (size={int(dim)}) "
                        f"in the exported program, but the provided Input has "
                        f"min_shape[{d}]={min_shape[d]}, max_shape[{d}]={max_shape[d]}. "
                        f"Static dimensions must be fixed."
                    )
                continue

            expr = dim.node.expr
            # Composite exprs (e.g. ``2*s0``) are recomputed by
            # ``ShapeEnv.bound_sympy``; overriding them directly would lie.
            if not isinstance(expr, sympy.Symbol):
                logger.debug(
                    "Input '%s' dim %d is a composite symbolic expression (%s) "
                    "bounded by another dynamic dimension; its range will be "
                    "derived from constituent symbols via bound_sympy.",
                    node.target,
                    d,
                    expr,
                )
                continue

            user_min = int(min_shape[d])
            user_max = int(max_shape[d])

            recorded = user_symbol_bounds.get(expr)
            if recorded is not None:
                # First occurrence wins; make a disagreement visible instead
                # of dropping the later bounds without a trace.
                if recorded != (user_min, user_max):
                    logger.warning(
                        "Input '%s' dim %d supplies bounds [%d, %d] for %s, "
                        "which already has bounds [%d, %d] from an earlier "
                        "input; keeping the earlier bounds.",
                        node.target,
                        d,
                        user_min,
                        user_max,
                        expr,
                        recorded[0],
                        recorded[1],
                    )
                continue

            user_symbol_bounds[expr] = (user_min, user_max)
            logger.debug(
                "Recorded user-supplied bounds for %s: [%d, %d]",
                expr,
                user_min,
                user_max,
            )

            _validate_user_range_against_export(expr, dim, user_min, user_max)

    return user_symbol_bounds
|
|
||
|
|
||
| @fn_supports_debugger # type: ignore[misc] | ||
| def compile_module( | ||
| gm: torch.fx.GraphModule, | ||
|
|
@@ -929,6 +1090,12 @@ def compile_module( | |
| if sample_kwarg_inputs is None: | ||
| sample_kwarg_inputs = {} | ||
|
|
||
| # Forwarded to the partitioner to fill Dim.DYNAMIC upper bounds. | ||
| # Read-only w.r.t. ShapeEnv so range_constraints survive save/re-export. | ||
| user_symbol_bounds = _build_user_symbol_bounds( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Instead of trying to handle args and kwargs, why not operate on the flattened pytree inputs? If you really do need to do this work this early, then you can always use in_spec to flatten the args and kwargs. Then the order is deterministic and you aren't doing name matching. |
||
| gm, sample_arg_inputs, sample_kwarg_inputs | ||
| ) | ||
|
|
||
| # Configure user compilation settings to converters. | ||
| CONVERTERS.set_compilation_settings(settings) | ||
|
|
||
|
|
@@ -1110,7 +1277,9 @@ def preserve_module_specs( | |
| ) | ||
|
|
||
| # Get the submodule inputs for min, opt, max shapes of the graph inputs | ||
| submodule_inputs = partitioning.construct_submodule_inputs(submodule) | ||
| submodule_inputs = partitioning.construct_submodule_inputs( | ||
| submodule, user_symbol_bounds=user_symbol_bounds | ||
| ) | ||
|
|
||
| assert submodule_inputs is not None | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same thing here: there should be a flat list of inputs to the graph.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We would need the placeholders for the FakeTensor values, right?