Commit f85e5c9

[Torch Dialect] support aten.isneginf, aten.isposinf, aten.nan_to_num (#2743)
1 parent f78ec78 commit f85e5c9

File tree: 8 files changed, +335 -0 lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 72 additions & 0 deletions
@@ -8543,6 +8543,52 @@ def Torch_AtenIsinfOp : Torch_Op<"aten.isinf", [
   }];
 }
 
+def Torch_AtenIsneginfOp : Torch_Op<"aten.isneginf", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::isneginf : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenIsneginfOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenIsneginfOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
+def Torch_AtenIsposinfOp : Torch_Op<"aten.isposinf", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::isposinf : (Tensor) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenIsposinfOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 1, 1);
+    }
+    void AtenIsposinfOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 1, 1);
+    }
+  }];
+}
+
 def Torch_AtenAllOp : Torch_Op<"aten.all", [
   AllowsTypeRefinement,
   HasValueSemantics,
@@ -10473,6 +10519,32 @@ def Torch_AtenWhereScalarSelfOp : Torch_Op<"aten.where.ScalarSelf", [
   }];
 }
 
+def Torch_AtenNanToNumOp : Torch_Op<"aten.nan_to_num", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::nan_to_num : (Tensor, float?, float?, float?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchOptionalFloatType:$nan,
+    AnyTorchOptionalFloatType:$posinf,
+    AnyTorchOptionalFloatType:$neginf
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenNanToNumOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenNanToNumOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+}
+
 def Torch_AtenSliceTensorOp : Torch_Op<"aten.slice.Tensor", [
   AllowsTypeRefinement,
   ReadOnly
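
For orientation, the eager PyTorch behavior behind the three generated signatures above (plain torch, nothing torch-mlir-specific; the tensor values are illustrative only):

import torch

x = torch.tensor([float("-inf"), -1.0, float("nan"), float("inf")])

# aten::isneginf / aten::isposinf : (Tensor) -> (Tensor)
print(torch.isneginf(x))  # tensor([ True, False, False, False])
print(torch.isposinf(x))  # tensor([False, False, False,  True])

# aten::nan_to_num : (Tensor, float?, float?, float?) -> (Tensor); the three
# optional floats map to the AnyTorchOptionalFloatType args nan/posinf/neginf.
print(torch.nan_to_num(x, nan=0.0, posinf=1e6, neginf=-1e6))
# tensor([-1000000., -1., 0., 1000000.])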

lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Lines changed: 62 additions & 0 deletions
@@ -6702,6 +6702,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.isneginf\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.isposinf\"(%arg0: !torch.list<int>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.ne.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
@@ -7874,6 +7882,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.nan_to_num\"(%arg0: !torch.list<int>, %arg1: !torch.optional<float>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.list<int> {\n"
+"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
+"    return %0 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.lerp.Tensor\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
 "    %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg1, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
 "    %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
@@ -9739,6 +9751,52 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %int11 = torch.constant.int 11\n"
 "    return %int11 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.isneginf\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %false = torch.constant.bool false\n"
+"    %int9 = torch.constant.int 9\n"
+"    %int10 = torch.constant.int 10\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.aten.ne.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.bool) {\n"
+"      %3 = torch.aten.ne.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %3 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %int11 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.isposinf\"(%arg0: !torch.tuple<int, int>) -> !torch.int {\n"
+"    %int11 = torch.constant.int 11\n"
+"    %none = torch.constant.none\n"
+"    %str = torch.constant.str \"AssertionError: \"\n"
+"    %false = torch.constant.bool false\n"
+"    %int9 = torch.constant.int 9\n"
+"    %int10 = torch.constant.int 10\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    %1 = torch.aten.ne.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n"
+"    %2 = torch.prim.If %1 -> (!torch.bool) {\n"
+"      %3 = torch.aten.ne.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n"
+"      torch.prim.If.yield %3 : !torch.bool\n"
+"    } else {\n"
+"      torch.prim.If.yield %false : !torch.bool\n"
+"    }\n"
+"    torch.prim.If %2 -> () {\n"
+"      torch.prim.If.yield\n"
+"    } else {\n"
+"      torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
+"      torch.prim.If.yield\n"
+"    }\n"
+"    return %int11 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.ne.Tensor\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
 "    %int11 = torch.constant.int 11\n"
 "    return %int11 : !torch.int\n"
@@ -10742,6 +10800,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
 "    return %4 : !torch.int\n"
 "  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.nan_to_num\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<float>, %arg2: !torch.optional<float>, %arg3: !torch.optional<float>) -> !torch.int {\n"
+"    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
+"    return %0#1 : !torch.int\n"
+"  }\n"
 "  func.func @\"__torch_mlir_dtype_fn.aten.nll_loss_forward\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.optional<tuple<int, int>>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<int, int> {\n"
 "    %none = torch.constant.none\n"
 "    %str = torch.constant.str \"AssertionError: \"\n"

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 80 additions & 0 deletions
@@ -932,6 +932,40 @@ class DecomposeAtenIsinfOp : public OpRewritePattern<AtenIsinfOp> {
 };
 } // namespace
 
+namespace {
+class DecomposeAtenIsneginfOp : public OpRewritePattern<AtenIsneginfOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenIsneginfOp op,
+                                PatternRewriter &rewriter) const override {
+    mlir::FloatType f64Type = rewriter.getF64Type();
+    Value inf = rewriter.create<ConstantFloatOp>(
+        op.getLoc(),
+        rewriter.getFloatAttr(
+            f64Type, APFloat::getInf(f64Type.getFloatSemantics(), true)));
+    rewriter.replaceOpWithNewOp<AtenEqScalarOp>(op, op.getType(), op.getSelf(),
+                                                inf);
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class DecomposeAtenIsposinfOp : public OpRewritePattern<AtenIsposinfOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenIsposinfOp op,
+                                PatternRewriter &rewriter) const override {
+    mlir::FloatType f64Type = rewriter.getF64Type();
+    Value inf = rewriter.create<ConstantFloatOp>(
+        op.getLoc(),
+        rewriter.getFloatAttr(f64Type,
+                              APFloat::getInf(f64Type.getFloatSemantics())));
+    rewriter.replaceOpWithNewOp<AtenEqScalarOp>(op, op.getType(), op.getSelf(),
+                                                inf);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAtenReshapeOp : public OpRewritePattern<AtenReshapeOp> {
 public:
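
The two patterns above reduce each predicate to a single aten.eq.Scalar against an f64 infinity constant; APFloat::getInf's trailing bool selects negative infinity. A minimal eager-mode sketch of the same rewrite (the helper names here are hypothetical, for illustration only):

import torch

def isneginf_decomposed(x: torch.Tensor) -> torch.Tensor:
    # Mirrors DecomposeAtenIsneginfOp: equality against a -inf scalar.
    return x == float("-inf")

def isposinf_decomposed(x: torch.Tensor) -> torch.Tensor:
    # Mirrors DecomposeAtenIsposinfOp: equality against a +inf scalar.
    return x == float("inf")

x = torch.tensor([float("-inf"), 0.0, float("nan"), float("inf")])
assert torch.equal(isneginf_decomposed(x), torch.isneginf(x))  # NaN == -inf is False
assert torch.equal(isposinf_decomposed(x), torch.isposinf(x))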
@@ -2471,6 +2505,49 @@ class DecomposeAtenWhereScalarSelfOp
 };
 } // namespace
 
+namespace {
+class DecomposeAtenNanToNumOp : public OpRewritePattern<AtenNanToNumOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenNanToNumOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    mlir::FloatType f64Type = rewriter.getF64Type();
+    Value nan = op.getNan();
+    Value posinf = op.getPosinf();
+    Value neginf = op.getNeginf();
+    auto baseType =
+        ValueTensorType::getWithLeastStaticInformation(op.getContext());
+    if (dyn_cast_or_null<ConstantNoneOp>(nan.getDefiningOp()))
+      nan = rewriter.create<ConstantFloatOp>(
+          loc, rewriter.getFloatAttr(
+                   f64Type, APFloat::getZero(f64Type.getFloatSemantics())));
+    if (dyn_cast_or_null<ConstantNoneOp>(posinf.getDefiningOp()))
+      posinf = rewriter.create<ConstantFloatOp>(
+          loc, rewriter.getFloatAttr(
+                   f64Type, APFloat::getInf(f64Type.getFloatSemantics())));
+    if (dyn_cast_or_null<ConstantNoneOp>(neginf.getDefiningOp()))
+      neginf = rewriter.create<ConstantFloatOp>(
+          loc,
+          rewriter.getFloatAttr(
+              f64Type, APFloat::getInf(f64Type.getFloatSemantics(), true)));
+    Value isNan =
+        rewriter.create<Torch::AtenIsnanOp>(loc, baseType, op.getSelf());
+    Value where = rewriter.create<Torch::AtenWhereScalarSelfOp>(
+        loc, baseType, isNan, nan, op.getSelf());
+    Value isposinf =
+        rewriter.create<Torch::AtenIsposinfOp>(loc, baseType, where);
+    where = rewriter.create<Torch::AtenWhereScalarSelfOp>(
+        loc, baseType, isposinf, posinf, where);
+    Value isneginf =
+        rewriter.create<Torch::AtenIsneginfOp>(loc, baseType, where);
+    rewriter.replaceOpWithNewOp<Torch::AtenWhereScalarSelfOp>(
+        op, op.getType(), isneginf, neginf, where);
+    return success();
+  }
+};
+} // namespace
+
 // Decompose aten.masked_fill.Scalar into aten.where.self op.
 namespace {
 class DecomposeAtenMaskedFillScalarOp
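
DecomposeAtenNanToNumOp chains three where.ScalarSelf selects, one per special value, materializing f64 constants where an operand is the None literal. A Python sketch of the same dataflow (one observation: the pattern falls back to literal +/-inf for an absent posinf/neginf, whereas torch.nan_to_num documents the dtype's greatest/least finite value for those defaults):

import torch

def nan_to_num_decomposed(x, nan=None, posinf=None, neginf=None):
    # Defaults as the pattern materializes them (0.0, +inf, -inf constants).
    nan = 0.0 if nan is None else nan
    posinf = float("inf") if posinf is None else posinf
    neginf = float("-inf") if neginf is None else neginf
    # where(isnan, nan, x) -> where(isposinf, posinf, .) -> where(isneginf, neginf, .)
    out = torch.where(torch.isnan(x), torch.full_like(x, nan), x)
    out = torch.where(torch.isposinf(out), torch.full_like(out, posinf), out)
    return torch.where(torch.isneginf(out), torch.full_like(out, neginf), out)

x = torch.tensor([float("nan"), float("inf"), float("-inf"), 2.0])
print(nan_to_num_decomposed(x, nan=0.0, posinf=9.0, neginf=-9.0))
# tensor([ 0.,  9., -9.,  2.])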
@@ -6393,6 +6470,7 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOtherOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenNanToNumOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenMaskedFillScalarOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSizeOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenReshapeOp>(patterns);
@@ -6448,6 +6526,8 @@ class DecomposeComplexOpsPass
     addPatternIfTargetOpIsIllegal<DecomposeAtenEyeMOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenIsnanOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenIsinfOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenIsneginfOp>(patterns);
+    addPatternIfTargetOpIsIllegal<DecomposeAtenIsposinfOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenRandLikeOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenHardsigmoidOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenRelu6Op>(patterns);

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 3 additions & 0 deletions
@@ -431,8 +431,11 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
   target.addIllegalOp<AtenZeroOp>();
   target.addIllegalOp<AtenEyeOp>();
   target.addIllegalOp<AtenEyeMOp>();
+  target.addIllegalOp<AtenNanToNumOp>();
   target.addIllegalOp<AtenIsnanOp>();
   target.addIllegalOp<AtenIsinfOp>();
+  target.addIllegalOp<AtenIsneginfOp>();
+  target.addIllegalOp<AtenIsposinfOp>();
   target.addIllegalOp<AtenRandLikeOp>();
   target.addIllegalOp<AtenHardsigmoidOp>();
   target.addIllegalOp<AtenRelu6Op>();
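
Marking the four ops illegal here is what obliges the decompositions registered above to fire: LowerToBackendContract keeps applying the simplification pipeline until no op that is illegal for the backend contract remains. A sketch of observing the effect through the pt1 Python entry point (treat torch_mlir.compile and the output_type spelling as assumptions about the pt1 API at this commit, not a verified invocation):

import torch
import torch_mlir  # pt1-era package under projects/pt1


class NanToNum(torch.nn.Module):
    def forward(self, x):
        return torch.nan_to_num(x, nan=0.0)


# In the printed module, aten.nan_to_num should be gone, replaced by the
# isnan/isposinf/isneginf + where.ScalarSelf chain from the decomposition.
print(torch_mlir.compile(NanToNum(), torch.rand(3, 4), output_type="torch"))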

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 7 additions & 0 deletions
@@ -473,6 +473,7 @@
     "ElementwiseAtenWhereSelfModule_basic",
     "ElementwiseWhereScalarOtherStaticModule_basic",
     "ElementwiseWhereScalarSelfStaticModule_basic",
+    "ElementwiseNanToNumModule_Basic",
     "ElementwiseBitwiseAndStaticShapeModule_basic",
     "ElementwiseBitwiseNotInt64Module_basic",
     "ElementwiseBitwiseNotInt32Module_basic",
@@ -1039,6 +1040,8 @@
     "ElementwiseAddScalar_TensorLiteralInt32_Module_basic",
     "ElementwiseAtenDivIntScalarModule_basic",
     "ElementwiseAtenIsinfOpModule_basic",
+    "ElementwiseAtenIsneginfOpModule_basic",
+    "ElementwiseAtenIsposinfOpModule_basic",
     "ElementwiseAtenLogicalOrOpBrodcastModule_basic",
     "ElementwiseAtenLogicalOrOpDiffArgs1Module_basic",
     "ElementwiseAtenLogicalOrOpDiffArgs2Module_basic",
@@ -1090,6 +1093,8 @@
     "ElementwiseGtIntTensorModule_basic",
     "ElementwiseGtMixed2ScalarModule_basic",
     "ElementwiseIsinfModule_basic",
+    "ElementwiseAtenIsneginfOpModule_basic",
+    "ElementwiseAtenIsposinfOpModule_basic",
     "ElementwiseIsnanModule_basic",
     "ElementwiseLeFloatTensorModule_basic",
     "ElementwiseLeIntTensorModule_basic",
@@ -1146,6 +1151,7 @@
     "ElementwiseUnaryModule_basic",
     "ElementwiseUnsqueezeBroadcastModule_basic",
     "ElementwiseWhereScalarModule_basic",
+    "ElementwiseNanToNumModule_Basic",
     "EmbeddingModule1DIndices_basic",
     "EmbeddingModuleI32Static_basic",
     "FlattenRank0Module_basic",
@@ -1511,6 +1517,7 @@
     "ElementwiseBitwiseAndScalarInt64Module_basic",
     "ElementwiseBitwiseAndScalarInt32Module_basic",
     "ElementwiseBitwiseAndScalarInt8Module_basic",
+    "ElementwiseNanToNumModule_Basic",
     "ElementwiseQuantizePerTensorModule_basic",
     "ElementwiseDequantizePerTensorModule_basic"
 }
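
The new entries name e2e test modules added alongside this change (the test definitions live in the eighth changed file, which is not expanded on this page). For context, a representative sketch of how such a module is written in the pt1 e2e framework; the shapes and replacement values below are illustrative, not the committed test:

import torch
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export


class ElementwiseNanToNumModule(torch.nn.Module):
    @export
    @annotate_args([None, ([3, 4], torch.float32, True)])
    def forward(self, a):
        return torch.nan_to_num(a, nan=0.0, posinf=1.0, neginf=-1.0)


@register_test_case(module_factory=lambda: ElementwiseNanToNumModule())
def ElementwiseNanToNumModule_Basic(module, tu: TestUtils):
    module.forward(
        torch.tensor([[float("nan"), 0.0, float("inf"), float("-inf")]] * 3))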

projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py

Lines changed: 29 additions & 0 deletions
@@ -341,6 +341,12 @@ def aten〇isnan〡shape(self: List[int]) -> List[int]:
 def aten〇isinf〡shape(self: List[int]) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇isneginf〡shape(self: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
+def aten〇isposinf〡shape(self: List[int]) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇ne〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(self, other)
 
@@ -1062,6 +1068,9 @@ def aten〇where〇ScalarOther〡shape(condition: List[int], self: List[int], ot
 def aten〇where〇ScalarSelf〡shape(condition: List[int], self: float, other: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(condition, other)
 
+def aten〇nan_to_num〡shape(self: List[int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> List[int]:
+    return upstream_shape_functions.unary(self)
+
 def aten〇lerp〇Tensor〡shape(self: List[int], end: List[int], weight: List[int]) -> List[int]:
     return upstream_shape_functions.broadcast(self, upstream_shape_functions.broadcast(end, weight))
 
@@ -2529,6 +2538,20 @@ def aten〇isnan〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
 def aten〇isinf〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
     return torch.bool
 
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex128, torch.complex64}))
+def aten〇isneginf〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    assert self_dtype != torch.complex128 and self_dtype != torch.complex64
+    return torch.bool
+
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex128, torch.complex64}))
+def aten〇isposinf〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    assert self_dtype != torch.complex128 and self_dtype != torch.complex64
+    return torch.bool
+
 @check_dtype_function(_check_two_tensor_op())
 def aten〇ne〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int:
     return torch.bool
@@ -3260,6 +3283,12 @@ def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], sel
     dtypes = [get_dtype_of_scalar(self), other_dtype]
     return promote_dtypes(ranks, dtypes)
 
+@check_dtype_function(
+    _check_tensors_with_the_same_dtype(num_of_tensors=1))
+def aten〇nan_to_num〡dtype(self_rank_dtype: Tuple[int, int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> int:
+    self_rank, self_dtype = self_rank_dtype
+    return self_dtype
+
 @check_dtype_function(
     [Invocation(TensorOfShape(2, 3, dtype=torch.float32), TensorOfShape(2, dtype=torch.int64),
                 TensorOfShape(3, dtype=torch.float32), reduction=0, ignore_index=0),
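
These Python definitions are the source of truth for the C string library in AbstractInterpLibrary.cpp shown earlier; the checked-in .cpp is regenerated from them by the build_tools update script rather than edited by hand. The shape contract they encode is the identity on shapes, which can be sanity-checked directly against the upstream helper:

# upstream_shape_functions resolves to torch.jit._shape_functions in PyTorch.
from torch.jit._shape_functions import unary

assert unary([2, 3, 5]) == [2, 3, 5]  # unary ops preserve the input shape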
