Skip to content

Commit fb41a82

Browse files
authored
DLFW 25.11 changes (#3889)
1 parent 3a9bc3b commit fb41a82

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from tensorrt import ITensor as TRTTensor
1010
from torch.fx.node import Argument, Node, Target
1111
from torch_tensorrt._features import needs_not_tensorrt_rtx
12-
from torch_tensorrt._utils import is_tensorrt_version_supported
12+
from torch_tensorrt._utils import is_tensorrt_version_supported, is_thor
1313
from torch_tensorrt.dynamo._settings import CompilationSettings
1414
from torch_tensorrt.dynamo._SourceIR import SourceIR
1515
from torch_tensorrt.dynamo.conversion import impl
@@ -424,9 +424,24 @@ def index_dtype_validator(
424424
return True
425425

426426

427+
def index_nonbool_validator(
428+
node: Node, settings: Optional[CompilationSettings] = None
429+
) -> bool:
430+
# for thor, we don't support boolean indices
431+
if is_thor():
432+
index = node.args[1]
433+
for ind in index:
434+
if ind is not None:
435+
val = ind.meta.get("val")
436+
if val is not None and val.dtype == torch.bool:
437+
return False
438+
return True
439+
440+
427441
@dynamo_tensorrt_converter(
428442
torch.ops.aten.index.Tensor,
429-
capability_validator=index_dtype_validator,
443+
capability_validator=lambda node, settings: index_dtype_validator(node, settings)
444+
and index_nonbool_validator(node, settings),
430445
supports_dynamic_shapes=True,
431446
requires_output_allocator=True,
432447
)
@@ -3601,10 +3616,18 @@ def aten_ops_full(
36013616
)
36023617

36033618

3619+
def nonzero_validator(
3620+
node: Node, settings: Optional[CompilationSettings] = None
3621+
) -> bool:
3622+
return not is_thor()
3623+
3624+
36043625
# currently nonzero is not supported for tensorrt_rtx
36053626
# TODO: lan to add the nonzero support once tensorrt_rtx team has added the support
3627+
# TODO: apbose to remove the capability validator once thor bug resolve in NGC
36063628
@dynamo_tensorrt_converter(
36073629
torch.ops.aten.nonzero.default,
3630+
capability_validator=nonzero_validator,
36083631
supports_dynamic_shapes=True,
36093632
requires_output_allocator=True,
36103633
)

tests/core/conversion/converters/test_reduce.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,14 @@ TEST(Converters, ATenAnyDimNegIndexConvertsCorrectly) {
345345
%3 : bool = prim::Constant[value=1]()
346346
%5 : Tensor = aten::any(%0, %1, %3)
347347
return (%5))IR";
348-
auto in = at::randint(-2, 2, {2, 32}, at::kCUDA);
348+
std::vector<int> data(64, 0);
349+
for (int i = 0; i < 64; ++i) {
350+
if (i % 7 == 0)
351+
data[i] = 1; // some positives
352+
if (i % 13 == 0)
353+
data[i] = -1; // some negatives
354+
}
355+
auto in = at::tensor(data, at::TensorOptions().dtype(at::kInt).device(at::kCUDA)).reshape({2, 32}); // shape [2, 32]
349356
test_body(graph, in);
350357
}
351358

0 commit comments

Comments
 (0)