Skip to content

Commit 5f8c6a3

Browse files
committed
skip at-simd in favor of explicitly reordering commutative reductions
1 parent 427bbb7 commit 5f8c6a3

File tree

4 files changed

+230
-49
lines changed

4 files changed

+230
-49
lines changed

base/cartesian.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,7 @@ function exprresolve_conditional(ex::Expr)
402402
return true, exprresolve_cond_dict[callee](ex.args[2], ex.args[3])
403403
end
404404
end
405-
elseif Meta.isexpr(ex, :block, 2) && ex.args[1] isa LineNumberNode
405+
elseif ex.head === :block && length(ex.args) == 2 && ex.args[1] isa LineNumberNode
406406
return exprresolve_conditional(ex.args[2])
407407
end
408408
false, false

base/multidimensional.jl

Lines changed: 160 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,16 +1049,15 @@ end
10491049
# Reduction support for CartesianIndices
10501050
# Split `inds` into sub-blocks. The values returned by `_halves(inds.indices)` are
# splatted into `map`, so they are combined position-wise: the k-th entries are
# gathered into a tuple and wrapped in a `CartesianIndices`.
# NOTE(review): presumably `_halves` yields one tuple of sub-ranges per dimension —
# confirm against its definition.
# The mapper is the composition `CartesianIndices∘tuple`; the fused spelling
# `CartesianIndicestuple` is not a defined name.
halves(inds::CartesianIndices) = map(CartesianIndices∘tuple, _halves(inds.indices)...)
10511051
function mapreduce_kernel(f, op, A, init, inds::CartesianIndices{N}) where {N}
1052-
N == 0 && return init===_InitialValue() ? mapreduce_first(f, op, only(A)) : op(init, f(only(A)))
1052+
N == 0 && return _mapreduce_start(f, op, A, init, A[inds[]])
10531053
N == 1 && return mapreduce_kernel(f, op, A, init, inds.indices[1])
1054-
i1, s = iterate(inds)
1055-
a1 = @inbounds A[i1]
1056-
v = _mapreduce_start(f, op, A, init, a1)
1054+
is_commutative_op(op) && return mapreduce_kernel_commutative(f, op, A, init, inds)
10571055
r = inds.indices[1]
10581056
if length(r) == 1
1059-
# SIMD over a one-element loop is less-than-helpful; just iterate
1060-
# over the rest of the indices without worrying about splitting out
1061-
# an inner SIMD loop
1057+
# A one-element inner loop is less-than-helpful
1058+
i1, s = iterate(inds)
1059+
a1 = @inbounds A[i1]
1060+
v = _mapreduce_start(f, op, A, init, a1)
10621061
for i in Iterators.rest(inds, s)
10631062
ai = @inbounds A[i]
10641063
v = op(v, f(ai))
@@ -1068,12 +1067,13 @@ function mapreduce_kernel(f, op, A, init, inds::CartesianIndices{N}) where {N}
10681067
# in the first iteration of the outer loop
10691068
outer = CartesianIndices(tail(inds.indices))
10701069
o1, so = iterate(outer)
1071-
@simd for i in r[begin+1:end]
1070+
v = _mapreduce_start(f, op, A, init, A[r[begin], o1])
1071+
for i in r[begin+1:end]
10721072
ai = @inbounds A[i, o1]
10731073
v = op(v, f(ai))
10741074
end
10751075
for o in Iterators.rest(outer, so)
1076-
@simd for i in r
1076+
for i in r
10771077
ai = @inbounds A[i, o]
10781078
v = op(v, f(ai))
10791079
end
@@ -1082,6 +1082,157 @@ function mapreduce_kernel(f, op, A, init, inds::CartesianIndices{N}) where {N}
10821082
return v
10831083
end
10841084

1085+
# Reduce `f(A[i])` over `i in inds` with `op`, starting via `_mapreduce_start`.
#
# Index sets with fewer than 16 entries are folded left-to-right into a single
# accumulator; anything longer is forwarded to `_mapreduce_kernel_commutative`.
# `inds` must be non-empty: its first entry is read unconditionally and all
# element accesses are `@inbounds`.
function mapreduce_kernel_commutative(f, op, A, init, inds::AbstractArray)
    # Long index sets go to the four-accumulator kernel.
    length(inds) >= 16 && return _mapreduce_kernel_commutative(f, op, A, init, inds)
    lo, hi = firstindex(inds), lastindex(inds)
    head = @inbounds A[inds[lo]]
    acc = _mapreduce_start(f, op, A, init, head)
    for k in (lo + 1):hi
        elt = @inbounds A[inds[k]]
        acc = op(acc, f(elt))
    end
    return acc
end
1097+
1098+
# This special internal method must have at least 4 indices and allows passing
1099+
# optional scalar leading and trailing dimensions
1100+
function _mapreduce_kernel_commutative(f, op, A, init, inds, leading=(), trailing=())
1101+
i1, iN = firstindex(inds), lastindex(inds)
1102+
n = length(inds)
1103+
@nexprs 4 N->a_N = @inbounds A[leading..., inds[i1+(N-1)], trailing...]
1104+
@nexprs 4 N->v_N = _mapreduce_start(f, op, A, init, a_N)
1105+
for batch in 1:(n>>2)-1
1106+
i = i1 + batch*4
1107+
@nexprs 4 N->a_N = @inbounds A[leading..., inds[i+(N-1)], trailing...]
1108+
@nexprs 4 N->fa_N = f(a_N)
1109+
@nexprs 4 N->v_N = op(v_N, fa_N)
1110+
end
1111+
v = op(op(v_1, v_2), op(v_3, v_4))
1112+
i = i1 + (n>>2)*4 - 1
1113+
i == iN && return v
1114+
for i in i+1:iN
1115+
ai = @inbounds A[leading..., inds[i], trailing...]
1116+
v = op(v, f(ai))
1117+
end
1118+
return v
1119+
end
1120+
1121+
# Multidimensional entry point: reduce `f(A[I])` over `I in inds` with `op`.
#
# * `N == 0`: the single element goes straight through `_mapreduce_start`.
# * `N == 1`: delegate to the method taking the lone index range.
# * First range has length 1 and the second has at least 4 entries: run the
#   unrolled kernel along the second dimension for every remaining outer index,
#   passing the scalar first index as a leading index.
# * First range has fewer than 4 entries: fall back to the iterator-based
#   kernel over `inds` itself and keep only the value from its result.
# * Otherwise: run the unrolled kernel along the first dimension for every
#   outer index.
# Per-slice results are combined with `op`.
function mapreduce_kernel_commutative(f::F, op::G, A, init, inds::CartesianIndices{N}) where {N,F,G}
    N == 0 && return _mapreduce_start(f, op, A, init, A[inds[]])
    N == 1 && return mapreduce_kernel_commutative(f, op, A, init, inds.indices[1])
    ranges = inds.indices
    inner, second = ranges[1], ranges[2]
    if length(inner) == 1 && length(second) >= 4
        # Singleton leading dimension: worth special-casing for dimensional reductions.
        lead = (only(inner),)
        outer_rest = CartesianIndices(tail(tail(ranges)))
        head, st = iterate(outer_rest)
        acc = _mapreduce_kernel_commutative(f, op, A, init, second, lead, head.I)
        for o in Iterators.rest(outer_rest, st)
            acc = op(acc, _mapreduce_kernel_commutative(f, op, A, init, second, lead, o.I))
        end
        return acc
    end
    if length(inner) < 4 # TODO: tune this number
        # Short inner dimension; there is room for a more specialized path here.
        elementwise = k -> f(A[k])
        return first(mapreduce_kernel_commutative(elementwise, op, inds, init, HasShape{N}(), length(inds)))
    end
    outer_all = CartesianIndices(tail(ranges))
    head, st = iterate(outer_all)
    acc = _mapreduce_kernel_commutative(f, op, A, init, inner, (), head.I)
    for o in Iterators.rest(outer_all, st)
        acc = op(acc, _mapreduce_kernel_commutative(f, op, A, init, inner, (), o.I))
    end
    return acc
end
1149+
1150+
@noinline _throw_iterator_assertion_error() = throw(AssertionError("iterator promised a length longer than it iterates"))
1151+
function mapreduce_kernel_commutative(f, op, itr, init, ::Union{HasLength, HasShape}, n, state...)
1152+
it = iterate(itr, state...)
1153+
it === nothing && _throw_iterator_assertion_error()
1154+
a, s = it
1155+
v_1 = _mapreduce_start(f, op, itr, init, a)
1156+
if n < 16
1157+
for _ in 2:n
1158+
it = iterate(itr, s)
1159+
it === nothing && _throw_iterator_assertion_error()
1160+
a, s = it
1161+
v_1 = op(v_1, f(a))
1162+
end
1163+
return v_1, s
1164+
end
1165+
@nexprs 3 n->begin
1166+
it = iterate(itr, s)
1167+
it === nothing && _throw_iterator_assertion_error()
1168+
a, s = it
1169+
v_{n+1} = _mapreduce_start(f, op, itr, init, a)
1170+
end
1171+
i = 4
1172+
for outer i in 8:4:n
1173+
@nexprs 4 n->begin
1174+
it = iterate(itr, s)
1175+
it === nothing && _throw_iterator_assertion_error()
1176+
a_n, s = it
1177+
end
1178+
@nexprs 4 n-> fa_n = f(a_n)
1179+
@nexprs 4 n-> v_n = op(v_n, fa_n)
1180+
end
1181+
v = op(op(v_1, v_2), op(v_3, v_4))
1182+
for _ in i+1:n
1183+
it = iterate(itr, s)
1184+
it === nothing && _throw_iterator_assertion_error()
1185+
a, s = it
1186+
v = op(v, f(a))
1187+
end
1188+
return v, s
1189+
end
1190+
function mapreduce_kernel_commutative(f, op, itr, init, ::IteratorSize, n, state...)
1191+
it = iterate(itr, state...)
1192+
it === nothing && return nothing
1193+
a, s = it
1194+
v_1 = _mapreduce_start(f, op, itr, init, a)
1195+
it = iterate(itr, s)
1196+
it === nothing && return Some(v_1)
1197+
a, s = it
1198+
v_2 = _mapreduce_start(f, op, itr, init, a)
1199+
it = iterate(itr, s)
1200+
it === nothing && return Some(op(v_1, v_2))
1201+
a, s = it
1202+
v_3 = _mapreduce_start(f, op, itr, init, a)
1203+
it = iterate(itr, s)
1204+
it === nothing && return Some(op(op(v_1, v_2), v_3))
1205+
a, s = it
1206+
v_4 = _mapreduce_start(f, op, itr, init, a)
1207+
for _ in 2:n>>2
1208+
@nexprs 4 N->begin
1209+
it = iterate(itr, s)
1210+
if it === nothing
1211+
N > 3 && (v_3 = op(v_3, f(a_3)))
1212+
N > 2 && (v_2 = op(v_2, f(a_2)))
1213+
N > 1 && (v_1 = op(v_1, f(a_1)))
1214+
return Some(op(op(v_1, v_2), op(v_3, v_4)))
1215+
end
1216+
a_N, s = it
1217+
end
1218+
@nexprs 4 N->fa_N = f(a_N)
1219+
@nexprs 4 N->v_N = op(v_N, fa_N)
1220+
end
1221+
v = op(op(v_1, v_2), op(v_3, v_4))
1222+
i = (n>>2)*4
1223+
@nexprs 4 N->begin
1224+
it = iterate(itr, s)
1225+
if it === nothing
1226+
return Some(v)
1227+
elseif n < i+N
1228+
return (v, s)
1229+
end
1230+
a, s = it
1231+
v = op(v, f(a))
1232+
end
1233+
return (v, s)
1234+
end
1235+
10851236
# Vector convenience method: forwards to the keyword form with `dims=1`.
function diff(a::AbstractVector)
    return diff(a; dims=1)
end
10861237

10871238
"""

base/permuteddimsarray.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,8 +343,7 @@ end
343343
P
344344
end
345345

346-
const CommutativeOps = Union{typeof(+),typeof(Base.add_sum),typeof(min),typeof(max),typeof(Base._extrema_rf),typeof(|),typeof(&)}
347-
346+
using Base: CommutativeOps
348347
# Whole-array (`dims::Colon`) reduction of a `PermutedDimsArray` with an operator
# in `CommutativeOps`: forward to the parent array instead of going through the
# permuted wrapper.
# NOTE(review): presumably valid because the result of a commutative reduction
# does not depend on traversal order — confirm for every member of `CommutativeOps`.
Base.mapreducedim(f, op::CommutativeOps, A::PermutedDimsArray, init, dims::Colon) =
    Base.mapreducedim(f, op, parent(A), init, dims)

0 commit comments

Comments
 (0)