@@ -1062,3 +1062,116 @@ func.func @ints_constant() -> !torch.vtensor<[2], si64> attributes {torch.onnx_m
10621062 return %0 : !torch.vtensor <[2 ],si64 >
10631063}
10641064
1065+ // CHECK-LABEL: @test_flatten_4d_axis_2
1066+ func.func @test_flatten_4d_axis_2 (%arg0: !torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[6 ,20 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1067+ // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2
1068+ // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
1069+ // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
1070+ // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
1071+ // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 1
1072+ // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32>
1073+ %0 = torch.operator " onnx.Flatten" (%arg0 ) {torch.onnx.axis = 2 : si64 } : (!torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[6 ,20 ],f32 >
1074+ return %0 : !torch.vtensor <[6 ,20 ],f32 >
1075+ }
1076+
1077+ // CHECK-LABEL: @test_flatten_4d_axis_0
1078+ func.func @test_flatten_4d_axis_0 (%arg0: !torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[1 ,120 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1079+ // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
1080+ // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
1081+ // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
1082+ // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
1083+ // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,120],f32>
1084+ %0 = torch.operator " onnx.Flatten" (%arg0 ) {torch.onnx.axis = 0 : si64 } : (!torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[1 ,120 ],f32 >
1085+ return %0 : !torch.vtensor <[1 ,120 ],f32 >
1086+ }
1087+
1088+ // CHECK-LABEL: @test_flatten_4d_axis_4
1089+ func.func @test_flatten_4d_axis_4 (%arg0: !torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[120 ,1 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1090+ // CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 4
1091+ // CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int -> !torch.vtensor
1092+ // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
1093+ // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 3
1094+ // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[120,1],f32>
1095+ %0 = torch.operator " onnx.Flatten" (%arg0 ) {torch.onnx.axis = 4 : si64 } : (!torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[120 ,1 ],f32 >
1096+ return %0 : !torch.vtensor <[120 ,1 ],f32 >
1097+ }
1098+
1099+ // CHECK-LABEL: @test_flatten_4d_axis_negative_2
1100+ func.func @test_flatten_4d_axis_negative_2 (%arg0: !torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[6 ,20 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1101+ // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2
1102+ // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
1103+ // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
1104+ // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
1105+ // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 1
1106+ // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32>
1107+ %0 = torch.operator " onnx.Flatten" (%arg0 ) {torch.onnx.axis = -2 : si64 } : (!torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[6 ,20 ],f32 >
1108+ return %0 : !torch.vtensor <[6 ,20 ],f32 >
1109+ }
1110+
1111+ // CHECK-LABEL: @test_flatten_4d_axis_negative_1
1112+ func.func @test_flatten_4d_axis_negative_1 (%arg0: !torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[24 ,5 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1113+ // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 3
1114+ // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
1115+ // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
1116+ // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
1117+ // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 2
1118+ // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[24,5],f32>
1119+ %0 = torch.operator " onnx.Flatten" (%arg0 ) {torch.onnx.axis = -1 : si64 } : (!torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[24 ,5 ],f32 >
1120+ return %0 : !torch.vtensor <[24 ,5 ],f32 >
1121+ }
1122+
1123+ // CHECK-LABEL: @test_flatten_4d_axis_negative_4
1124+ func.func @test_flatten_4d_axis_negative_4 (%arg0: !torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[1 ,120 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1125+ // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
1126+ // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3
1127+ // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor
1128+ // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
1129+ // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,120],f32>
1130+ %0 = torch.operator " onnx.Flatten" (%arg0 ) {torch.onnx.axis = -4 : si64 } : (!torch.vtensor <[2 ,3 ,4 ,5 ],f32 >) -> !torch.vtensor <[1 ,120 ],f32 >
1131+ return %0 : !torch.vtensor <[1 ,120 ],f32 >
1132+ }
1133+
1134+ // CHECK-LABEL: @test_flatten_2d_axis_1
1135+ func.func @test_flatten_2d_axis_1 (%arg0: !torch.vtensor <[2 ,3 ],f32 >) -> !torch.vtensor <[2 ,3 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1136+ // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 1
1137+ // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 1
1138+ // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor
1139+ // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
1140+ // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 0
1141+ // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32>
1142+ %0 = torch.operator " onnx.Flatten" (%arg0 ) {torch.onnx.axis = 1 : si64 } : (!torch.vtensor <[2 ,3 ],f32 >) -> !torch.vtensor <[2 ,3 ],f32 >
1143+ return %0 : !torch.vtensor <[2 ,3 ],f32 >
1144+ }
1145+
1146+ // CHECK-LABEL: @test_flatten_1d_axis_0
1147+ func.func @test_flatten_1d_axis_0 (%arg0: !torch.vtensor <[2 ],f32 >) -> !torch.vtensor <[1 ,2 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1148+ // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
1149+ // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 0
1150+ // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor
1151+ // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
1152+ // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,2],f32>
1153+ %0 = torch.operator " onnx.Flatten" (%arg0 ) {torch.onnx.axis = 0 : si64 } : (!torch.vtensor <[2 ],f32 >) -> !torch.vtensor <[1 ,2 ],f32 >
1154+ return %0 : !torch.vtensor <[1 ,2 ],f32 >
1155+ }
1156+
1157+ // CHECK-LABEL: @test_flatten_1d_axis_negative_1
1158+ func.func @test_flatten_1d_axis_negative_1 (%arg0: !torch.vtensor <[2 ],f32 >) -> !torch.vtensor <[1 ,2 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1159+ // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0
1160+ // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 0
1161+ // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor
1162+ // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0
1163+ // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,2],f32>
1164+ %0 = torch.operator " onnx.Flatten" (%arg0 ) {torch.onnx.axis = -1 : si64 } : (!torch.vtensor <[2 ],f32 >) -> !torch.vtensor <[1 ,2 ],f32 >
1165+ return %0 : !torch.vtensor <[1 ,2 ],f32 >
1166+ }
1167+
1168+ // COM: CHECK-LABEL: @test_flatten_1d_axis_1
1169+ func.func @test_flatten_1d_axis_1 (%arg0: !torch.vtensor <[2 ],f32 >) -> !torch.vtensor <[2 ,1 ],f32 > attributes {torch.onnx_meta.ir_version = 7 : si64 , torch.onnx_meta.opset_version = 13 : si64 , torch.onnx_meta.producer_name = " backend-test" , torch.onnx_meta.producer_version = " " } {
1170+ // CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 1
1171+ // CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor
1172+ // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0
1173+ // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 0
1174+ // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[2,1],f32>
1175+ %0 = torch.operator " onnx.Flatten" (%arg0 ) {torch.onnx.axis = 1 : si64 } : (!torch.vtensor <[2 ],f32 >) -> !torch.vtensor <[2 ,1 ],f32 >
1176+ return %0 : !torch.vtensor <[2 ,1 ],f32 >
1177+ }
0 commit comments