Skip to content
173 changes: 171 additions & 2 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
import os
import platform
import warnings
from typing import Any, Collection, List, Optional, Sequence, Union
from typing import Any, Collection, Dict, List, Optional, Sequence, Tuple, Union

import sympy
import torch
import torch.utils._pytree as pytree
from torch.export import ExportedProgram
from torch.fx.node import Target
from torch.utils._sympy.numbers import int_oo
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import EngineCapability, dtype
from torch_tensorrt._features import ENABLED_FEATURES, needs_cross_compile
Expand Down Expand Up @@ -898,6 +901,164 @@ def _insert_complex_io_adapters(
partitioned_module.recompile()


def _build_user_symbol_bounds(
    gm: torch.fx.GraphModule,
    sample_arg_inputs: Sequence[Input],
    sample_kwarg_inputs: dict[Any, Any],
) -> Dict[sympy.Symbol, Tuple[int, int]]:
    """Collect ``{sympy.Symbol: (min, max)}`` from the user's dynamic ``Input``s.

    Each graph placeholder is paired positionally with one flattened sample
    input. For every dynamic ``Input`` whose placeholder carries a tensor in
    ``node.meta["val"]``, each dimension that is a plain symbol records the
    user's ``(min_shape[d], max_shape[d])``. The ``ShapeEnv`` is only read
    (``var_to_range.get``), never written.

    Args:
        gm: Graph module whose ``placeholder`` nodes are inspected.
        sample_arg_inputs: Positional sample inputs.
        sample_kwarg_inputs: Keyword sample inputs.

    Returns:
        Mapping from symbol to the user-supplied ``(min, max)``. When several
        inputs share a symbol, the first one encountered wins.

    Raises:
        ValueError: If an ``Input``'s rank differs from the placeholder's, if
            a static dimension is given a non-fixed range, or if the user's
            range falls outside a finite exporter range (except the
            ``user_min == 1`` / ``exp_min == 2`` case, which only warns).
    """
    placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
    # NOTE(review): the placeholders are needed here for their fake-tensor
    # ``meta["val"]``; a reviewer asked whether a flat list of graph inputs
    # could be used instead -- TODO confirm.

    # Flatten args+kwargs in pytree order and pair them positionally with the
    # placeholders, so no name matching is required. This assumes the pytree
    # order matches the placeholder order.
    # NOTE(review): a reviewer suggested taking the flattening spec from the
    # exported program (``in_spec``) instead of re-deriving it here -- TODO
    # confirm.
    flat_inputs, _ = pytree.tree_flatten((list(sample_arg_inputs), sample_kwarg_inputs))

    # ``zip`` stops at the shorter sequence; a count mismatch means bounds may
    # be dropped or paired with the wrong placeholder, so make it visible.
    if len(placeholders) != len(flat_inputs):
        logger.warning(
            "Graph has %d placeholder(s) but %d flattened sample input(s) "
            "were provided; user-supplied dynamic shape bounds are matched "
            "positionally and may be incomplete or misaligned.",
            len(placeholders),
            len(flat_inputs),
        )

    user_symbol_bounds: Dict[sympy.Symbol, Tuple[int, int]] = {}

    for node, inp in zip(placeholders, flat_inputs):
        if not (isinstance(inp, Input) and inp.shape_mode == Input._ShapeMode.DYNAMIC):
            continue
        fake_val = node.meta.get("val")
        if not isinstance(fake_val, torch.Tensor):
            continue

        min_shape = inp.shape["min_shape"]
        max_shape = inp.shape["max_shape"]

        if len(fake_val.size()) != len(min_shape):
            raise ValueError(
                f"Input '{node.target}' has {len(fake_val.size())} dimensions in "
                f"the exported program, but the provided Input specifies "
                f"{len(min_shape)} dimensions. Ensure Input.min_shape, "
                f"Input.opt_shape, and Input.max_shape each have "
                f"{len(fake_val.size())} entries."
            )

        for d, dim in enumerate(fake_val.size()):
            if not isinstance(dim, torch.SymInt):
                # Static in the graph: the user's range must collapse to it.
                if min_shape[d] != dim or max_shape[d] != dim:
                    raise ValueError(
                        f"Input '{node.target}' dim {d} is static (size={int(dim)}) "
                        f"in the exported program, but the provided Input has "
                        f"min_shape[{d}]={min_shape[d]}, max_shape[{d}]={max_shape[d]}. "
                        f"Static dimensions must be fixed."
                    )
                continue

            expr = dim.node.expr
            # Composite exprs (e.g. ``2*s0``) are not recorded; overriding
            # them directly would contradict their constituent symbols.
            if not isinstance(expr, sympy.Symbol):
                logger.debug(
                    "Input '%s' dim %d is a composite symbolic expression (%s) "
                    "bounded by another dynamic dimension; its range will be "
                    "derived from constituent symbols via bound_sympy.",
                    node.target,
                    d,
                    expr,
                )
                continue

            user_min = int(min_shape[d])
            user_max = int(max_shape[d])

            recorded = user_symbol_bounds.get(expr)
            if recorded is not None:
                # First occurrence wins; surface a disagreement instead of
                # silently discarding the later range.
                if recorded != (user_min, user_max):
                    logger.warning(
                        "Input '%s' dim %d reuses dynamic dimension '%s' with "
                        "range [%d, %d], which differs from the previously "
                        "recorded range [%d, %d]; keeping the first.",
                        node.target,
                        d,
                        expr,
                        user_min,
                        user_max,
                        recorded[0],
                        recorded[1],
                    )
                continue

            user_symbol_bounds[expr] = (user_min, user_max)
            logger.debug(
                "Recorded user-supplied bounds for %s: [%d, %d]",
                expr,
                user_min,
                user_max,
            )

            _check_user_bounds_against_exporter(dim, expr, user_min, user_max)

    return user_symbol_bounds


def _check_user_bounds_against_exporter(
    dim: torch.SymInt,
    expr: sympy.Symbol,
    user_min: int,
    user_max: int,
) -> None:
    """Validate a user ``[min, max]`` against the exporter's finite range for ``expr``.

    Returns without checking when no ``ShapeEnv`` or range is reachable, when
    the exporter's upper bound is unbounded, or when the exporter bounds
    cannot be converted to ``int``. Identical ranges pass silently.

    Raises:
        ValueError: If ``user_max`` exceeds the exporter max, or ``user_min``
            is below the exporter min (other than the 1 -> 2 case).
    """
    # The exported program may already bound this symbol to a finite range
    # (e.g. Dim("batch", min=10, max=20)). Validating here, at compile time,
    # reports a mismatch against the shapes the user declared in
    # Input.min_shape / Input.max_shape before any runtime shape check.
    shape_env = getattr(dim.node, "shape_env", None)
    if shape_env is None:
        return
    exp_range = shape_env.var_to_range.get(expr)
    if exp_range is None:
        return

    exp_lower = exp_range.lower
    exp_upper = exp_range.upper
    if exp_upper is int_oo or exp_upper == sympy.oo:
        # Dim.DYNAMIC: user fills the gap (intended use).
        return
    try:
        exp_min = int(exp_lower)
        exp_max = int(exp_upper)
    except (TypeError, ValueError):
        return
    if user_min == exp_min and user_max == exp_max:
        return

    mismatch = (
        f"Dynamic dimension '{expr}': "
        f"Input range [{user_min}, {user_max}] vs "
        f"exported program range [{exp_min}, {exp_max}]."
    )

    if user_max > exp_max:
        raise ValueError(
            f"{mismatch} Input.max_shape ({user_max}) exceeds the "
            f"exported program's max ({exp_max}). The program was "
            f"exported with this dimension bounded to "
            f"[{exp_min}, {exp_max}], so the compiled TensorRT engine "
            f"cannot accept shapes above {exp_max}. Either re-export "
            f"with Dim('{expr}', max={user_max}) or set "
            f"Input.max_shape <= {exp_max}."
        )

    if user_min < exp_min:
        # 1->2 is the 0/1 specialization artifact, not a user error.
        # NOTE(review): ``extract_var_range_info`` appears to map an exporter
        # min of 2 back to 1, so the "engine's min will be 2" wording below
        # may not match the final profile -- TODO confirm.
        if user_min == 1 and exp_min == 2:
            logger.warning(
                "%s Input.min_shape=1 but the exported program's min "
                "is 2 (PyTorch 0/1 specialization -- Dim(min=1) is "
                "recorded as min=2). The compiled engine's min will "
                "be 2.",
                mismatch,
            )
            return
        raise ValueError(
            f"{mismatch} Input.min_shape ({user_min}) is below the "
            f"exported program's min ({exp_min}). The program was "
            f"exported with this dimension bounded to "
            f"[{exp_min}, {exp_max}], so the compiled TensorRT engine "
            f"cannot accept shapes below {exp_min}. Either re-export "
            f"with Dim('{expr}', min={user_min}) or set "
            f"Input.min_shape >= {exp_min}."
        )

    # Strict subset: the user's range lies inside the exporter's. Logged at
    # info level rather than as a warning -- the user got exactly what they
    # asked for.
    logger.info(
        "%s Narrowing engine profile to user bounds [%d, %d] "
        "(exported program range was [%d, %d]).",
        mismatch,
        user_min,
        user_max,
        exp_min,
        exp_max,
    )


@fn_supports_debugger # type: ignore[misc]
def compile_module(
gm: torch.fx.GraphModule,
Expand Down Expand Up @@ -929,6 +1090,12 @@ def compile_module(
if sample_kwarg_inputs is None:
sample_kwarg_inputs = {}

# Forwarded to the partitioner to fill Dim.DYNAMIC upper bounds.
# Read-only w.r.t. ShapeEnv so range_constraints survive save/re-export.
user_symbol_bounds = _build_user_symbol_bounds(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of trying to handle args and kwargs separately, why not operate on the flattened pytree inputs? If you really do need to do this work this early, you can always use in_spec to flatten the args and kwargs. Then the order is deterministic and you aren't doing name matching.

gm, sample_arg_inputs, sample_kwarg_inputs
)

# Configure user compilation settings to converters.
CONVERTERS.set_compilation_settings(settings)

Expand Down Expand Up @@ -1110,7 +1277,9 @@ def preserve_module_specs(
)

# Get the submodule inputs for min, opt, max shapes of the graph inputs
submodule_inputs = partitioning.construct_submodule_inputs(submodule)
submodule_inputs = partitioning.construct_submodule_inputs(
submodule, user_symbol_bounds=user_symbol_bounds
)

assert submodule_inputs is not None

Expand Down
38 changes: 32 additions & 6 deletions py/torch_tensorrt/dynamo/partitioning/common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import logging
from typing import Any, Dict, Optional, Sequence, Set, Tuple

import sympy
import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily

from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo.utils import (
COMPLEX_TO_REAL_DTYPE,
Expand All @@ -20,11 +20,14 @@ def construct_dynamic_input(
input_dtype: torch.dtype,
name: str = "",
is_shape_tensor: bool = False,
user_symbol_bounds: Optional[Dict[sympy.Symbol, Tuple[int, int]]] = None,
) -> Input:
"""
Constructs a torch_tensorrt.Input based on a symbolic input
Args:
input_shape: A symbolic shape / regular shape of a tensor (which can have a mix of SymInt nodes and static values)
user_symbol_bounds: Optional ``{sym: (min, max)}`` map; forwarded to
:func:`extract_var_range_info` to fill unbounded exporter uppers.
Returns:
A dynamic shaped torch_tensorrt.Input which has the properties of the symbolic shaped input.
"""
Expand All @@ -33,7 +36,9 @@ def construct_dynamic_input(
max_shape = []
for d, dim in enumerate(input_shape):
if isinstance(dim, torch.SymInt):
min_max_opt = extract_var_range_info(dim)
min_max_opt = extract_var_range_info(
dim, user_symbol_bounds=user_symbol_bounds
)
unwrapped_min_max_opt: Dict[str, int] = {}
if "min" not in min_max_opt or min_max_opt["min"] is None:
logger.warning(
Expand Down Expand Up @@ -85,9 +90,12 @@ def get_input(
dtype: torch.dtype,
name: str = "",
is_shape_tensor: bool = False,
user_symbol_bounds: Optional[Dict[sympy.Symbol, Tuple[int, int]]] = None,
) -> Input:
"""
Based on type of dimensions in the input_shape, construct regular or dynamic shaped inputs
Based on type of dimensions in the input_shape, construct regular or dynamic shaped inputs.

``user_symbol_bounds`` is forwarded to :func:`construct_dynamic_input`.
"""
if dtype in COMPLEX_TO_REAL_DTYPE:
real_dtype = COMPLEX_TO_REAL_DTYPE[dtype]
Expand All @@ -106,19 +114,25 @@ def get_input(
dtype,
name=name,
is_shape_tensor=is_shape_tensor,
user_symbol_bounds=user_symbol_bounds,
)
else:
return Input(
shape=input_shape, dtype=dtype, name=name, is_shape_tensor=is_shape_tensor
)


def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
def construct_submodule_inputs(
module: torch.fx.GraphModule,
user_symbol_bounds: Optional[Dict[sympy.Symbol, Tuple[int, int]]] = None,
) -> Sequence[Input]:
"""
Construct torch_tensorrt Inputs based on the module inputs.
The module inputs will have meta data which has the shape and dtype info
Args:
module: Input FX GraphModule
user_symbol_bounds: Optional ``{sym: (min, max)}`` map; forwarded to
:func:`get_input` to fill unbounded exporter uppers.
Returns:
Sequence of torch_tensorrt.Input's representing inputs to given module
"""
Expand All @@ -134,7 +148,12 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
if isinstance(input_meta, (FakeTensor, torch.Tensor)):
input_shape = input_meta.size()
torchtrt_inputs.append(
get_input(input_shape, input_meta.dtype, name=input.name)
get_input(
input_shape,
input_meta.dtype,
name=input.name,
user_symbol_bounds=user_symbol_bounds,
)
)
elif isinstance(input_meta, torch.SymInt):
# Assuming sym_integers | shape inputs always have torch.int64 dtype
Expand All @@ -144,6 +163,7 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
torch.int64,
name=input.name,
is_shape_tensor=True,
user_symbol_bounds=user_symbol_bounds,
)
)
elif isinstance(input_meta, torch.SymFloat):
Expand All @@ -153,6 +173,7 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
torch.float32,
name=input.name,
is_shape_tensor=False, # Only SymInt inputs are treated as shape tensors
user_symbol_bounds=user_symbol_bounds,
)
)
else:
Expand All @@ -164,7 +185,12 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
input_meta = input.meta["tensor_meta"]
input_shape = input_meta.shape
torchtrt_inputs.append(
get_input(input_shape, input_meta.dtype, name=input.name)
get_input(
input_shape,
input_meta.dtype,
name=input.name,
user_symbol_bounds=user_symbol_bounds,
)
)
else:
raise AssertionError(
Expand Down
50 changes: 43 additions & 7 deletions py/torch_tensorrt/dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,9 +406,16 @@ def contains_sym_int(tensor: torch.Tensor) -> bool:
return any(isinstance(dim, torch.SymInt) for dim in tensor)


def extract_var_range_info(symbolic_integer: torch.SymInt) -> Dict[str, Optional[int]]:
"""
This function returns the min, max, opt values of a symbolic integer.
def extract_var_range_info(
symbolic_integer: torch.SymInt,
user_symbol_bounds: Optional[Dict[sympy.Symbol, Tuple[int, int]]] = None,
) -> Dict[str, Optional[int]]:
"""Return ``{min, max, opt}`` for a symbolic integer.

``user_symbol_bounds`` (read-only ``{sym: (min, max)}``) is applied whenever
the expression is a plain symbol present in the map: it fills an unbounded
exporter upper, and otherwise the tighter of the user's and exporter's bound is kept.
The lower is intersected with the exporter's so the 0/1 specialization
survives even if the user passes ``min_shape=0``.
"""
node = symbolic_integer.node
expr = node.expr
Expand All @@ -435,13 +442,42 @@ def extract_var_range_info(symbolic_integer: torch.SymInt) -> Dict[str, Optional
or expr.xreplace(var_to_val_map)
)
assert var_range, var_val
min_val, max_val = (
int(var_range.lower),
int(var_range.upper) if var_range.upper != int_oo else None,
)

# ``var_to_range`` returns ``int_oo`` for unbounded; ``bound_sympy`` (used
# for composite exprs like ``s0+s1``) returns ``sympy.oo`` instead. They
# are distinct objects -- check both, else ``int(sympy.oo)`` raises.
def _bound_to_int_or_none(value: Any) -> Optional[int]:
if value is int_oo or value is -int_oo:
return None
if value == sympy.oo or value == -sympy.oo:
return None
try:
return int(value)
except (TypeError, OverflowError, AttributeError):
return None

min_val_opt = _bound_to_int_or_none(var_range.lower)
max_val = _bound_to_int_or_none(var_range.upper)
# Unbounded lower shouldn't happen for tensor dims; fall back to 1.
min_val = min_val_opt if min_val_opt is not None else 1

# Torchdynamo 0/1 specialization outlier
min_val = 1 if min_val == 2 else min_val

# Apply user bounds whenever present. ``_build_user_symbol_bounds`` already
# rejects user ranges that exceed the exporter, so the only cases reaching
# here are: Dim.DYNAMIC (max_val is None), strict subset, or the 1->2
# specialization. Clamp defensively in case validation was skipped (no
# ShapeEnv access path).
if (
user_symbol_bounds
and isinstance(expr, sympy.Symbol)
and expr in user_symbol_bounds
):
user_min, user_max = user_symbol_bounds[expr]
min_val = max(min_val, int(user_min))
max_val = int(user_max) if max_val is None else min(max_val, int(user_max))

min_max_opt: Dict[str, Optional[int]] = {}
min_max_opt["min"] = min_val
min_max_opt["max"] = max_val
Expand Down
Loading
Loading