diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index c3dbc095a745..70ff037e9c3a 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -1111,6 +1111,11 @@ class ConvertAtenMultipleDimsReductionOp
       for (int64_t i = 0; i < inputRank; i++)
         reduceDims.push_back(i);
     }
+    // PyTorch treats an explicit empty list the same as "reduce all dims".
+    if (reduceDims.empty()) {
+      for (int64_t i = 0; i < inputRank; i++)
+        reduceDims.push_back(i);
+    }
 
     int64_t N = reduceDims.size();
     for (unsigned i = 0; i < N; i++) {
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
index 036f0f2e5110..1378e2a902ef 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp
@@ -785,13 +785,23 @@ std::optional<Value> convertReduceOpCommon(
 
     // Optionally squeeze out the reduced axes.
     if (!keep_dims) {
+      auto squeezedType =
+          RankedTensorType::get(output_shape, reduce_element_type);
       auto reshape_op = CreateOpAndInfer<tosa::ReshapeOp>(
-          rewriter, op->getLoc(), output_type, val,
+          rewriter, op->getLoc(), squeezedType, val,
           tosa::getTosaConstShape(rewriter, op->getLoc(), output_shape));
       val = reshape_op.getResult();
     }
   }
 
+  // Ensure the result element type matches the expected output type.
+  if (val.getType() != output_type) {
+    auto casted = tosa::tosaCastTensorToType(rewriter, val, output_type);
+    if (!casted)
+      return std::nullopt;
+    val = casted.value();
+  }
+
   return val;
 }
 
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 4c8318570c7b..d9eb6c05a957 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -3431,6 +3431,8 @@
     "ElementwiseClampMinModule_bfloat16",
     "ElementwiseClampModule_bfloat16",
     "ElementwiseReluModule_bfloat16",
+    # torch.onnx.errors.SymbolicValueError: Cannot determine scalar type for this ''
+    "ReduceSumEmptyDimListInt8ToInt32Module_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.3.0.dev"):
@@ -3822,7 +3824,6 @@
     "MaxPool3dWithIndicesNonDefaultStrideModule_basic",
     "MaxPool3dWithIndicesStaticModule_basic",
     "MaxPool3dSingleIntTupleDilationModule_basic",
-    "MeanDimEmptyDimModule_basic",
     "MlGroupNormManualModule_basic",
     "MlGroupNormModule_basic",
     "MlLayerNormManualModule_basic",
@@ -3877,7 +3878,6 @@
     "ReduceL3NormKeepDimComplexModule_basic",
     "ReduceMaxAlongDimUnsignedInt_basic",
     "ReduceMinAlongDimUnsignedInt_basic",
-    "ReduceSumDimIntListEmptyDimModule_basic",
     "RollModule_basic",
     "ScalarConstantTupleModule_basic",
     "ScalarImplicitFloatModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
index 0eb0545e7f11..2e4ba9c4ccfc 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py
@@ -58,6 +58,52 @@ def ReduceSumDtypeFloatModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ReduceSumEmptyDimListInt8ToInt32Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.int8, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.sum(a, dim=[], dtype=torch.int32)
+
+
+@register_test_case(module_factory=lambda: ReduceSumEmptyDimListInt8ToInt32Module())
+def ReduceSumEmptyDimListInt8ToInt32Module_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, 5, low=-16, high=16).to(torch.int8))
+
+
+# ==============================================================================
+
+
+class ReduceSumEmptyDimListInt8Module(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1, -1], torch.int8, True),
+        ]
+    )
+    def forward(self, a):
+        return torch.sum(a, dim=[])
+
+
+@register_test_case(module_factory=lambda: ReduceSumEmptyDimListInt8Module())
+def ReduceSumEmptyDimListInt8Module_basic(module, tu: TestUtils):
+    module.forward(tu.randint(3, 4, 5, low=-16, high=16).to(torch.int8))
+
+
+# ==============================================================================
+
+
 class ReduceSumElementTypeBoolModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index bb04a9772f99..5b28c77fdf50 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -325,6 +325,53 @@ func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[3,4,5,6],f32>) -> !torch.vtensor<[3,4,6],f32> {
 
 // -----
 
+// CHECK-LABEL: func.func @test_reduce_sum_empty_dims$basic(
+// CHECK-SAME: %[[INPUT_F32:.*]]: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[],f32> {
+// CHECK: %[[INPUT_F32_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_F32]] : !torch.vtensor<[2,3,4],f32> -> tensor<2x3x4xf32>
+// CHECK: %[[NONE:.*]] = torch.constant.none
+// CHECK: %[[EMPTY_DIMS:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+// CHECK: %[[SUM_DIM0:.*]] = tosa.reduce_sum %[[INPUT_F32_TENSOR]] {axis = 0 : i32} : (tensor<2x3x4xf32>) -> tensor<1x3x4xf32>
+// CHECK: %[[SUM_DIM1:.*]] = tosa.reduce_sum %[[SUM_DIM0]] {axis = 1 : i32} : (tensor<1x3x4xf32>) -> tensor<1x1x4xf32>
+// CHECK: %[[SUM_DIM2:.*]] = tosa.reduce_sum %[[SUM_DIM1]] {axis = 2 : i32} : (tensor<1x1x4xf32>) -> tensor<1x1x1xf32>
+// CHECK: %[[SCALAR_SHAPE:.*]] = tosa.const_shape
+// CHECK: %[[RESHAPED_SCALAR:.*]] = tosa.reshape %[[SUM_DIM2]], %[[SCALAR_SHAPE]] : (tensor<1x1x1xf32>, !tosa.shape<0>) -> tensor<f32>
+// CHECK: %[[RESULT_F32:.*]] = torch_c.from_builtin_tensor %[[RESHAPED_SCALAR]] : tensor<f32> -> !torch.vtensor<[],f32>
+// CHECK: return %[[RESULT_F32]] : !torch.vtensor<[],f32>
+// CHECK: }
+func.func @test_reduce_sum_empty_dims$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[],f32> {
+  %dtype_none = torch.constant.none
+  %keep_dims_false = torch.constant.bool false
+  %all_dims_list = torch.prim.ListConstruct : () -> !torch.list<int>
+  %sum_all_dims = torch.aten.sum.dim_IntList %arg0, %all_dims_list, %keep_dims_false, %dtype_none : !torch.vtensor<[2,3,4],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
+  return %sum_all_dims : !torch.vtensor<[],f32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_reduce_sum_empty_dims_i8_to_i32$basic(
+// CHECK-SAME: %[[INPUT_I8:.*]]: !torch.vtensor<[2,3,4],si8>) -> !torch.vtensor<[],si32> {
+// CHECK: %[[INPUT_I8_TENSOR:.*]] = torch_c.to_builtin_tensor %[[INPUT_I8]] : !torch.vtensor<[2,3,4],si8> -> tensor<2x3x4xi8>
+// CHECK: %[[DTYPE_I32:.*]] = torch.constant.int 3
+// CHECK: %[[EMPTY_DIMS:.*]] = torch.prim.ListConstruct : () -> !torch.list<int>
+// CHECK: %[[CAST_INPUT_TO_I32:.*]] = tosa.cast %[[INPUT_I8_TENSOR]] : (tensor<2x3x4xi8>) -> tensor<2x3x4xi32>
+// CHECK: %[[SUM_DIM0:.*]] = tosa.reduce_sum %[[CAST_INPUT_TO_I32]] {axis = 0 : i32} : (tensor<2x3x4xi32>) -> tensor<1x3x4xi32>
+// CHECK: %[[SUM_DIM1:.*]] = tosa.reduce_sum %[[SUM_DIM0]] {axis = 1 : i32} : (tensor<1x3x4xi32>) -> tensor<1x1x4xi32>
+// CHECK: %[[SUM_DIM2:.*]] = tosa.reduce_sum %[[SUM_DIM1]] {axis = 2 : i32} : (tensor<1x1x4xi32>) -> tensor<1x1x1xi32>
+// CHECK: %[[SCALAR_SHAPE:.*]] = tosa.const_shape
+// CHECK: %[[RESHAPED_SCALAR:.*]] = tosa.reshape %[[SUM_DIM2]], %[[SCALAR_SHAPE]] : (tensor<1x1x1xi32>, !tosa.shape<0>) -> tensor<i32>
+// CHECK: %[[RESULT_I32:.*]] = torch_c.from_builtin_tensor %[[RESHAPED_SCALAR]] : tensor<i32> -> !torch.vtensor<[],si32>
+// CHECK: return %[[RESULT_I32]] : !torch.vtensor<[],si32>
+// CHECK: }
+func.func @test_reduce_sum_empty_dims_i8_to_i32$basic(%arg0: !torch.vtensor<[2,3,4],si8>) -> !torch.vtensor<[],si32> {
+  %dtype_i32 = torch.constant.int 3
+  %keep_dims_false = torch.constant.bool false
+  %all_dims_list = torch.prim.ListConstruct : () -> !torch.list<int>
+  %sum_all_dims_to_i32 = torch.aten.sum.dim_IntList %arg0, %all_dims_list, %keep_dims_false, %dtype_i32 : !torch.vtensor<[2,3,4],si8>, !torch.list<int>, !torch.bool, !torch.int -> !torch.vtensor<[],si32>
+  return %sum_all_dims_to_i32 : !torch.vtensor<[],si32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_linalg_vector_norm$basic(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,151,64],f32>) -> !torch.vtensor<[3,151,1],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,151,64],f32> -> tensor<3x151x64xf32>