|
9 | 9 | from tensorrt import ITensor as TRTTensor |
10 | 10 | from torch.fx.node import Argument, Node, Target |
11 | 11 | from torch_tensorrt._features import needs_not_tensorrt_rtx |
12 | | -from torch_tensorrt._utils import is_tensorrt_version_supported |
| 12 | +from torch_tensorrt._utils import is_tensorrt_version_supported, is_thor |
13 | 13 | from torch_tensorrt.dynamo._settings import CompilationSettings |
14 | 14 | from torch_tensorrt.dynamo._SourceIR import SourceIR |
15 | 15 | from torch_tensorrt.dynamo.conversion import impl |
@@ -424,9 +424,24 @@ def index_dtype_validator( |
424 | 424 | return True |
425 | 425 |
|
426 | 426 |
|
| 427 | +def index_nonbool_validator( |
| 428 | + node: Node, settings: Optional[CompilationSettings] = None |
| 429 | +) -> bool: |
| 430 | + # for thor, we don't support boolean indices |
| 431 | + if is_thor(): |
| 432 | + index = node.args[1] |
| 433 | + for ind in index: |
| 434 | + if ind is not None: |
| 435 | + val = ind.meta.get("val") |
| 436 | + if val is not None and val.dtype == torch.bool: |
| 437 | + return False |
| 438 | + return True |
| 439 | + |
| 440 | + |
427 | 441 | @dynamo_tensorrt_converter( |
428 | 442 | torch.ops.aten.index.Tensor, |
429 | | - capability_validator=index_dtype_validator, |
| 443 | + capability_validator=lambda node, settings: index_dtype_validator(node, settings) |
| 444 | + and index_nonbool_validator(node, settings), |
430 | 445 | supports_dynamic_shapes=True, |
431 | 446 | requires_output_allocator=True, |
432 | 447 | ) |
@@ -3601,10 +3616,18 @@ def aten_ops_full( |
3601 | 3616 | ) |
3602 | 3617 |
|
3603 | 3618 |
|
| 3619 | +def nonzero_validator( |
| 3620 | + node: Node, settings: Optional[CompilationSettings] = None |
| 3621 | +) -> bool: |
| 3622 | + return not is_thor() |
| 3623 | + |
| 3624 | + |
3604 | 3625 | # currently nonzero is not supported for tensorrt_rtx |
3605 | 3626 | # TODO: lan to add the nonzero support once tensorrt_rtx team has added the support |
| 3627 | +# TODO: apbose to remove the capability validator once thor bug resolve in NGC |
3606 | 3628 | @dynamo_tensorrt_converter( |
3607 | 3629 | torch.ops.aten.nonzero.default, |
| 3630 | + capability_validator=nonzero_validator, |
3608 | 3631 | supports_dynamic_shapes=True, |
3609 | 3632 | requires_output_allocator=True, |
3610 | 3633 | ) |
|
0 commit comments