Skip to content

Commit 38ec76f

Browse files
committed
Improve readability in Bumper interface
1 parent a900883 commit 38ec76f

File tree

1 file changed

+21
-29
lines changed

1 file changed

+21
-29
lines changed

ext/DynamicExpressionsBumperExt.jl

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -15,46 +15,23 @@ function bumper_eval_tree_array(
1515
_result_ok = tree_mapreduce(
1616
# Leaf nodes, we create an allocation and fill
1717
# it with the value of the leaf:
18-
leaf -> begin
18+
leaf_node -> begin
1919
ar = @alloc(T, n)
20-
ok = if leaf.constant
21-
v = leaf.val::T
20+
ok = if leaf_node.constant
21+
v = leaf_node.val::T
2222
ar .= v
2323
isfinite(v)
2424
else
25-
ar .= view(cX, leaf.feature, :)
25+
ar .= view(cX, leaf_node.feature, :)
2626
true
2727
end
2828
ResultOk(ar, ok)
2929
end,
3030
# Branch nodes, we simply pass them to the evaluation kernel:
31-
branch -> branch,
31+
branch_node -> branch_node,
3232
# In the evaluation kernel, we combine the branch nodes
3333
# with the arrays created by the leaf nodes:
34-
((branch, cumulators::Vararg{Any,M}) where {M}) -> begin
35-
if M == 1
36-
if cumulators[1].ok
37-
out = dispatch_kern1!(operators.unaops, branch.op, cumulators[1].x)
38-
ResultOk(out, !is_bad_array(out))
39-
else
40-
cumulators[1]
41-
end
42-
else
43-
if cumulators[1].ok && cumulators[2].ok
44-
out = dispatch_kern2!(
45-
operators.binops,
46-
branch.op,
47-
cumulators[1].x,
48-
cumulators[2].x,
49-
)
50-
ResultOk(out, !is_bad_array(out))
51-
elseif cumulators[1].ok
52-
cumulators[2]
53-
else
54-
cumulators[1]
55-
end
56-
end
57-
end,
34+
((args::Vararg{Any,M}) where {M}) -> dispatch_kerns!(operators, args...),
5835
tree;
5936
break_sharing=Val(true),
6037
)
@@ -64,6 +41,21 @@ function bumper_eval_tree_array(
6441
end
6542
return (result, ok)
6643
end
44+
45+
function dispatch_kerns!(operators, branch_node, cumulator)
46+
cumulator.ok || return cumulator
47+
48+
out = dispatch_kern1!(operators.unaops, branch_node.op, cumulator.x)
49+
return ResultOk(out, !is_bad_array(out))
50+
end
51+
function dispatch_kerns!(operators, branch_node, cumulator1, cumulator2)
52+
cumulator1.ok || return cumulator1
53+
cumulator2.ok || return cumulator2
54+
55+
out = dispatch_kern2!(operators.binops, branch_node.op, cumulator1.x, cumulator2.x)
56+
return ResultOk(out, !is_bad_array(out))
57+
end
58+
6759
@generated function dispatch_kern1!(unaops, op_idx, cumulator)
6860
nuna = counttuple(unaops)
6961
quote

0 commit comments

Comments
 (0)