@@ -15,46 +15,23 @@ function bumper_eval_tree_array(
1515 _result_ok = tree_mapreduce (
1616 # Leaf nodes, we create an allocation and fill
1717 # it with the value of the leaf:
18- leaf -> begin
18+ leaf_node -> begin
1919 ar = @alloc (T, n)
20- ok = if leaf . constant
21- v = leaf . val:: T
20+ ok = if leaf_node . constant
21+ v = leaf_node . val:: T
2222 ar .= v
2323 isfinite (v)
2424 else
25- ar .= view (cX, leaf . feature, :)
25+ ar .= view (cX, leaf_node . feature, :)
2626 true
2727 end
2828 ResultOk (ar, ok)
2929 end ,
3030 # Branch nodes, we simply pass them to the evaluation kernel:
31- branch -> branch ,
31+ branch_node -> branch_node ,
3232 # In the evaluation kernel, we combine the branch nodes
3333 # with the arrays created by the leaf nodes:
34- ((branch, cumulators:: Vararg{Any,M} ) where {M}) -> begin
35- if M == 1
36- if cumulators[1 ]. ok
37- out = dispatch_kern1! (operators. unaops, branch. op, cumulators[1 ]. x)
38- ResultOk (out, ! is_bad_array (out))
39- else
40- cumulators[1 ]
41- end
42- else
43- if cumulators[1 ]. ok && cumulators[2 ]. ok
44- out = dispatch_kern2! (
45- operators. binops,
46- branch. op,
47- cumulators[1 ]. x,
48- cumulators[2 ]. x,
49- )
50- ResultOk (out, ! is_bad_array (out))
51- elseif cumulators[1 ]. ok
52- cumulators[2 ]
53- else
54- cumulators[1 ]
55- end
56- end
57- end ,
34+ ((args:: Vararg{Any,M} ) where {M}) -> dispatch_kerns! (operators, args... ),
5835 tree;
5936 break_sharing= Val (true ),
6037 )
@@ -64,6 +41,21 @@ function bumper_eval_tree_array(
6441 end
6542 return (result, ok)
6643end
44+
45+ function dispatch_kerns! (operators, branch_node, cumulator)
46+ cumulator. ok || return cumulator
47+
48+ out = dispatch_kern1! (operators. unaops, branch_node. op, cumulator. x)
49+ return ResultOk (out, ! is_bad_array (out))
50+ end
51+ function dispatch_kerns! (operators, branch_node, cumulator1, cumulator2)
52+ cumulator1. ok || return cumulator1
53+ cumulator2. ok || return cumulator2
54+
55+ out = dispatch_kern2! (operators. binops, branch_node. op, cumulator1. x, cumulator2. x)
56+ return ResultOk (out, ! is_bad_array (out))
57+ end
58+
6759@generated function dispatch_kern1! (unaops, op_idx, cumulator)
6860 nuna = counttuple (unaops)
6961 quote
0 commit comments