@@ -94,34 +94,43 @@ This holds options for expression evaluation, such as evaluation backend.
9494- `buffer::Union{ArrayBuffer,Nothing}`: If not `nothing`, use this buffer for evaluation.
9595 This should be an instance of `ArrayBuffer` which has an `array` field and an
9696 `index` field used to iterate which buffer slot to use.
97+ - `use_fused::Val{U}=Val(true)`: If `Val{true}`, use fused kernels for faster
98+ evaluation. Setting this to `Val{false}` will skip the fused kernels, meaning that
99+ you would only need to overload `deg0_eval`, `deg1_eval` and `deg2_eval` for custom
100+ evaluation.
97101"""
98- struct EvalOptions{T,B,E,BUF<: Union{ArrayBuffer,Nothing} }
102+ struct EvalOptions{T,B,E,BUF<: Union{ArrayBuffer,Nothing} ,U }
99103 turbo:: Val{T}
100104 bumper:: Val{B}
101105 early_exit:: Val{E}
102106 buffer:: BUF
107+ use_fused:: Val{U}
103108end
104109
105110@unstable function EvalOptions (;
106111 turbo:: Union{Bool,Val} = Val (false ),
107112 bumper:: Union{Bool,Val} = Val (false ),
108113 early_exit:: Union{Bool,Val} = Val (true ),
109114 buffer:: Union{ArrayBuffer,Nothing} = nothing ,
115+ use_fused:: Union{Bool,Val} = Val (true ),
110116)
111117 v_turbo = _to_bool_val (turbo)
112118 v_bumper = _to_bool_val (bumper)
113119 v_early_exit = _to_bool_val (early_exit)
120+ v_use_fused = _to_bool_val (use_fused)
114121
115122 if v_bumper isa Val{true }
116123 @assert buffer === nothing
117124 end
118125
119- return EvalOptions (v_turbo, v_bumper, v_early_exit, buffer)
126+ return EvalOptions (v_turbo, v_bumper, v_early_exit, buffer, v_use_fused )
120127end
121128
122129@unstable @inline _to_bool_val (x:: Bool ) = x ? Val (true ) : Val (false )
123130@inline _to_bool_val (:: Val{T} ) where {T} = Val (T:: Bool )
124131
132+ @inline use_fused (eval_options:: EvalOptions ) = eval_options. use_fused isa Val{true }
133+
125134_copy (x) = copy (x)
126135_copy (:: Nothing ) = nothing
127136function Base. copy (eval_options:: EvalOptions )
@@ -130,6 +139,7 @@ function Base.copy(eval_options::EvalOptions)
130139 bumper= eval_options. bumper,
131140 early_exit= eval_options. early_exit,
132141 buffer= _copy (eval_options. buffer),
142+ use_fused= eval_options. use_fused,
133143 )
134144end
135145
@@ -433,19 +443,20 @@ end
433443 end
434444 end
435445 return quote
446+ fused = use_fused (eval_options)
436447 return Base. Cartesian. @nif (
437448 $ nbin,
438449 i -> i == op_idx, # COV_EXCL_LINE
439450 i -> let op = operators. binops[i] # COV_EXCL_LINE
440- if get_child (tree, 1 ). degree == 0 && get_child (tree, 2 ). degree == 0
451+ if fused && get_child (tree, 1 ). degree == 0 && get_child (tree, 2 ). degree == 0
441452 deg2_l0_r0_eval (tree, cX, op, eval_options)
442- elseif get_child (tree, 2 ). degree == 0
453+ elseif fused && get_child (tree, 2 ). degree == 0
443454 result_l = _eval_tree_array (get_child (tree, 1 ), cX, operators, eval_options)
444455 ! result_l. ok && return result_l
445456 @return_on_nonfinite_array (eval_options, result_l. x)
446457 # op(x, y), where y is a constant or variable but x is not.
447458 deg2_r0_eval (tree, result_l. x, cX, op, eval_options)
448- elseif get_child (tree, 1 ). degree == 0
459+ elseif fused && get_child (tree, 1 ). degree == 0
449460 result_r = _eval_tree_array (get_child (tree, 2 ), cX, operators, eval_options)
450461 ! result_r. ok && return result_r
451462 @return_on_nonfinite_array (eval_options, result_r. x)
@@ -487,19 +498,22 @@ end
487498 # This @nif lets us generate an if statement over choice of operator,
488499 # which means the compiler will be able to completely avoid type inference on operators.
489500 return quote
501+ fused = use_fused (eval_options)
490502 Base. Cartesian. @nif (
491503 $ nuna,
492504 i -> i == op_idx, # COV_EXCL_LINE
493505 i -> let op = operators. unaops[i] # COV_EXCL_LINE
494- if get_child (tree, 1 ). degree == 2 &&
506+ if fused &&
507+ get_child (tree, 1 ). degree == 2 &&
495508 get_child (get_child (tree, 1 ), 1 ). degree == 0 &&
496509 get_child (get_child (tree, 1 ), 2 ). degree == 0
497510 # op(op2(x, y)), where x, y, z are constants or variables.
498511 l_op_idx = get_child (tree, 1 ). op
499512 dispatch_deg1_l2_ll0_lr0_eval (
500513 tree, cX, op, l_op_idx, operators. binops, eval_options
501514 )
502- elseif get_child (tree, 1 ). degree == 1 &&
515+ elseif fused &&
516+ get_child (tree, 1 ). degree == 1 &&
503517 get_child (get_child (tree, 1 ), 1 ). degree == 0
504518 # op(op2(x)), where x is a constant or variable.
505519 l_op_idx = get_child (tree, 1 ). op
0 commit comments