@@ -1049,16 +1049,15 @@ end
10491049# Reduction support for CartesianIndices
10501050halves (inds:: CartesianIndices ) = map (CartesianIndices∘ tuple, _halves (inds. indices)... )
10511051function mapreduce_kernel (f, op, A, init, inds:: CartesianIndices{N} ) where {N}
1052- N == 0 && return init === _InitialValue () ? mapreduce_first ( f, op, only (A)) : op ( init, f ( only (A)) )
1052+ N == 0 && return _mapreduce_start ( f, op, A, init, A[inds[]] )
10531053 N == 1 && return mapreduce_kernel (f, op, A, init, inds. indices[1 ])
1054- i1, s = iterate (inds)
1055- a1 = @inbounds A[i1]
1056- v = _mapreduce_start (f, op, A, init, a1)
1054+ is_commutative_op (op) && return mapreduce_kernel_commutative (f, op, A, init, inds)
10571055 r = inds. indices[1 ]
10581056 if length (r) == 1
1059- # SIMD over a one-element loop is less-than-helpful; just iterate
1060- # over the rest of the indices without worrying about splitting out
1061- # an inner SIMD loop
1057+ # A one-element inner loop is less-than-helpful
1058+ i1, s = iterate (inds)
1059+ a1 = @inbounds A[i1]
1060+ v = _mapreduce_start (f, op, A, init, a1)
10621061 for i in Iterators. rest (inds, s)
10631062 ai = @inbounds A[i]
10641063 v = op (v, f (ai))
@@ -1068,12 +1067,13 @@ function mapreduce_kernel(f, op, A, init, inds::CartesianIndices{N}) where {N}
10681067 # in the first iteration of the outer loop
10691068 outer = CartesianIndices (tail (inds. indices))
10701069 o1, so = iterate (outer)
1071- @simd for i in r[begin + 1 : end ]
1070+ v = _mapreduce_start (f, op, A, init, A[r[begin ], o1])
1071+ for i in r[begin + 1 : end ]
10721072 ai = @inbounds A[i, o1]
10731073 v = op (v, f (ai))
10741074 end
10751075 for o in Iterators. rest (outer, so)
1076- @simd for i in r
1076+ for i in r
10771077 ai = @inbounds A[i, o]
10781078 v = op (v, f (ai))
10791079 end
@@ -1082,6 +1082,157 @@ function mapreduce_kernel(f, op, A, init, inds::CartesianIndices{N}) where {N}
10821082 return v
10831083end
10841084
1085+ function mapreduce_kernel_commutative (f, op, A, init, inds:: AbstractArray )
1086+ if length (inds) < 16
1087+ i1, iN = firstindex (inds), lastindex (inds)
1088+ v_1 = _mapreduce_start (f, op, A, init, @inbounds A[inds[i1]])
1089+ for i in i1+ 1 : iN
1090+ a = @inbounds A[inds[i]]
1091+ v_1 = op (v_1, f (a))
1092+ end
1093+ return v_1
1094+ end
1095+ return _mapreduce_kernel_commutative (f, op, A, init, inds)
1096+ end
1097+
1098+ # This special internal method must have at least 4 indices and allows passing
1099+ # optional scalar leading and trailing dimensions
1100+ function _mapreduce_kernel_commutative (f, op, A, init, inds, leading= (), trailing= ())
1101+ i1, iN = firstindex (inds), lastindex (inds)
1102+ n = length (inds)
1103+ @nexprs 4 N-> a_N = @inbounds A[leading... , inds[i1+ (N- 1 )], trailing... ]
1104+ @nexprs 4 N-> v_N = _mapreduce_start (f, op, A, init, a_N)
1105+ for batch in 1 : (n>> 2 )- 1
1106+ i = i1 + batch* 4
1107+ @nexprs 4 N-> a_N = @inbounds A[leading... , inds[i+ (N- 1 )], trailing... ]
1108+ @nexprs 4 N-> fa_N = f (a_N)
1109+ @nexprs 4 N-> v_N = op (v_N, fa_N)
1110+ end
1111+ v = op (op (v_1, v_2), op (v_3, v_4))
1112+ i = i1 + (n>> 2 )* 4 - 1
1113+ i == iN && return v
1114+ for i in i+ 1 : iN
1115+ ai = @inbounds A[leading... , inds[i], trailing... ]
1116+ v = op (v, f (ai))
1117+ end
1118+ return v
1119+ end
1120+
1121+ function mapreduce_kernel_commutative (f:: F , op:: G , A, init, inds:: CartesianIndices{N} ) where {N,F,G}
1122+ N == 0 && return _mapreduce_start (f, op, A, init, A[inds[]])
1123+ N == 1 && return mapreduce_kernel_commutative (f, op, A, init, inds. indices[1 ])
1124+ is = inds. indices[1 ]
1125+ js = inds. indices[2 ]
1126+ if length (is) == 1 && length (js) >= 4
1127+ # It's quite useful to optimize this case for dimensional reductions
1128+ i = only (is)
1129+ outer = CartesianIndices (tail (tail (inds. indices)))
1130+ o1, s = iterate (outer)
1131+ v = _mapreduce_kernel_commutative (f, op, A, init, js, (i,), o1. I)
1132+ for o in Iterators. rest (outer, s)
1133+ v = op (v, _mapreduce_kernel_commutative (f, op, A, init, js, (i,), o. I))
1134+ end
1135+ return v
1136+ elseif length (is) < 4 # TODO : tune this number
1137+ # These small cases could be further optimized
1138+ return mapreduce_kernel_commutative (i-> f (A[i]), op, inds, init, HasShape {N} (), length (inds))[1 ]
1139+ else
1140+ outer = CartesianIndices (tail (inds. indices))
1141+ o1, s = iterate (outer)
1142+ v = _mapreduce_kernel_commutative (f, op, A, init, is, (), o1. I)
1143+ for o in Iterators. rest (outer, s)
1144+ v = op (v, _mapreduce_kernel_commutative (f, op, A, init, is, (), o. I))
1145+ end
1146+ return v
1147+ end
1148+ end
1149+
1150+ @noinline _throw_iterator_assertion_error () = throw (AssertionError (" iterator promised a length longer than it iterates" ))
1151+ function mapreduce_kernel_commutative (f, op, itr, init, :: Union{HasLength, HasShape} , n, state... )
1152+ it = iterate (itr, state... )
1153+ it === nothing && _throw_iterator_assertion_error ()
1154+ a, s = it
1155+ v_1 = _mapreduce_start (f, op, itr, init, a)
1156+ if n < 16
1157+ for _ in 2 : n
1158+ it = iterate (itr, s)
1159+ it === nothing && _throw_iterator_assertion_error ()
1160+ a, s = it
1161+ v_1 = op (v_1, f (a))
1162+ end
1163+ return v_1, s
1164+ end
1165+ @nexprs 3 n-> begin
1166+ it = iterate (itr, s)
1167+ it === nothing && _throw_iterator_assertion_error ()
1168+ a, s = it
1169+ v_{n+ 1 } = _mapreduce_start (f, op, itr, init, a)
1170+ end
1171+ i = 4
1172+ for outer i in 8 : 4 : n
1173+ @nexprs 4 n-> begin
1174+ it = iterate (itr, s)
1175+ it === nothing && _throw_iterator_assertion_error ()
1176+ a_n, s = it
1177+ end
1178+ @nexprs 4 n-> fa_n = f (a_n)
1179+ @nexprs 4 n-> v_n = op (v_n, fa_n)
1180+ end
1181+ v = op (op (v_1, v_2), op (v_3, v_4))
1182+ for _ in i+ 1 : n
1183+ it = iterate (itr, s)
1184+ it === nothing && _throw_iterator_assertion_error ()
1185+ a, s = it
1186+ v = op (v, f (a))
1187+ end
1188+ return v, s
1189+ end
1190+ function mapreduce_kernel_commutative (f, op, itr, init, :: IteratorSize , n, state... )
1191+ it = iterate (itr, state... )
1192+ it === nothing && return nothing
1193+ a, s = it
1194+ v_1 = _mapreduce_start (f, op, itr, init, a)
1195+ it = iterate (itr, s)
1196+ it === nothing && return Some (v_1)
1197+ a, s = it
1198+ v_2 = _mapreduce_start (f, op, itr, init, a)
1199+ it = iterate (itr, s)
1200+ it === nothing && return Some (op (v_1, v_2))
1201+ a, s = it
1202+ v_3 = _mapreduce_start (f, op, itr, init, a)
1203+ it = iterate (itr, s)
1204+ it === nothing && return Some (op (op (v_1, v_2), v_3))
1205+ a, s = it
1206+ v_4 = _mapreduce_start (f, op, itr, init, a)
1207+ for _ in 2 : n>> 2
1208+ @nexprs 4 N-> begin
1209+ it = iterate (itr, s)
1210+ if it === nothing
1211+ N > 3 && (v_3 = op (v_3, f (a_3)))
1212+ N > 2 && (v_2 = op (v_2, f (a_2)))
1213+ N > 1 && (v_1 = op (v_1, f (a_1)))
1214+ return Some (op (op (v_1, v_2), op (v_3, v_4)))
1215+ end
1216+ a_N, s = it
1217+ end
1218+ @nexprs 4 N-> fa_N = f (a_N)
1219+ @nexprs 4 N-> v_N = op (v_N, fa_N)
1220+ end
1221+ v = op (op (v_1, v_2), op (v_3, v_4))
1222+ i = (n>> 2 )* 4
1223+ @nexprs 4 N-> begin
1224+ it = iterate (itr, s)
1225+ if it === nothing
1226+ return Some (v)
1227+ elseif n < i+ N
1228+ return (v, s)
1229+ end
1230+ a, s = it
1231+ v = op (v, f (a))
1232+ end
1233+ return (v, s)
1234+ end
1235+
10851236diff (a:: AbstractVector ) = diff (a, dims= 1 )
10861237
10871238"""
0 commit comments