@@ -206,17 +206,33 @@ function eval_grad_tree_array(
206206 variable:: Union{Bool,Val} = Val (false ),
207207 turbo:: Union{Bool,Val} = Val (false ),
208208) where {T<: Number }
209- n_gradients = if isa (variable, Val{true }) || (isa (variable, Bool) && variable)
209+ variable_mode = isa (variable, Val{true }) || (isa (variable, Bool) && variable)
210+ constant_mode = isa (variable, Val{false }) || (isa (variable, Bool) && ! variable)
211+ both_mode = isa (variable, Val{:both })
212+
213+ n_gradients = if variable_mode
210214 size (cX, 1 ):: Int
211- else
215+ elseif constant_mode
212216 count_constants (tree):: Int
217+ elseif both_mode
218+ size (cX, 1 ) + count_constants (tree)
213219 end
214- result = if isa (variable, Val{true }) || (variable isa Bool && variable)
220+
221+ result = if variable_mode
215222 eval_grad_tree_array (tree, n_gradients, nothing , cX, operators, Val (true ))
216- else
223+ elseif constant_mode
217224 index_tree = index_constants (tree)
218- eval_grad_tree_array (tree, n_gradients, index_tree, cX, operators, Val (false ))
219- end
225+ eval_grad_tree_array (
226+ tree, n_gradients, index_tree, cX, operators, Val (false )
227+ )
228+ elseif both_mode
229+ # features come first because we can use size(cX, 1) to skip them
230+ index_tree = index_constants (tree)
231+ eval_grad_tree_array (
232+ tree, n_gradients, index_tree, cX, operators, Val (:both )
233+ )
234+ end :: ResultOk2
235+
220236 return (result. x, result. dx, result. ok)
221237end
222238
@@ -226,11 +242,9 @@ function eval_grad_tree_array(
226242 index_tree:: Union{NodeIndex,Nothing} ,
227243 cX:: AbstractMatrix{T} ,
228244 operators:: OperatorEnum ,
229- :: Val{variable} ,
230- ):: ResultOk2 where {T<: Number ,variable}
231- result = _eval_grad_tree_array (
232- tree, n_gradients, index_tree, cX, operators, Val (variable)
233- )
245+ :: Val{mode} ,
246+ ):: ResultOk2 where {T<: Number ,mode}
247+ result = _eval_grad_tree_array (tree, n_gradients, index_tree, cX, operators, Val (mode))
234248 ! result. ok && return result
235249 return ResultOk2 (
236250 result. x, result. dx, ! (is_bad_array (result. x) || is_bad_array (result. dx))
@@ -260,30 +274,18 @@ end
260274 index_tree:: Union{NodeIndex,Nothing} ,
261275 cX:: AbstractMatrix{T} ,
262276 operators:: OperatorEnum ,
263- :: Val{variable } ,
264- ):: ResultOk2 where {T<: Number ,variable }
277+ :: Val{mode } ,
278+ ):: ResultOk2 where {T<: Number ,mode }
265279 nuna = get_nuna (operators)
266280 nbin = get_nbin (operators)
267281 deg1_branch_skeleton = quote
268282 grad_deg1_eval (
269- tree,
270- n_gradients,
271- index_tree,
272- cX,
273- operators. unaops[i],
274- operators,
275- Val (variable),
283+ tree, n_gradients, index_tree, cX, operators. unaops[i], operators, Val (mode)
276284 )
277285 end
278286 deg2_branch_skeleton = quote
279287 grad_deg2_eval (
280- tree,
281- n_gradients,
282- index_tree,
283- cX,
284- operators. binops[i],
285- operators,
286- Val (variable),
288+ tree, n_gradients, index_tree, cX, operators. binops[i], operators, Val (mode)
287289 )
288290 end
289291 deg1_branch = if nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN
310312 end
311313 quote
312314 if tree. degree == 0
313- grad_deg0_eval (tree, n_gradients, index_tree, cX, Val (variable ))
315+ grad_deg0_eval (tree, n_gradients, index_tree, cX, Val (mode ))
314316 elseif tree. degree == 1
315317 $ deg1_branch
316318 else
@@ -324,8 +326,8 @@ function grad_deg0_eval(
324326 n_gradients,
325327 index_tree:: Union{NodeIndex,Nothing} ,
326328 cX:: AbstractMatrix{T} ,
327- :: Val{variable } ,
328- ):: ResultOk2 where {T<: Number ,variable }
329+ :: Val{mode } ,
330+ ):: ResultOk2 where {T<: Number ,mode }
329331 const_part = deg0_eval (tree, cX). x
330332
331333 zero_mat = if isa (cX, Array)
@@ -334,17 +336,26 @@ function grad_deg0_eval(
334336 hcat ([fill_similar (zero (T), cX, axes (cX, 2 )) for _ in 1 : n_gradients]. .. )'
335337 end
336338
337- if variable == tree. constant
339+ if (mode isa Bool && mode == tree. constant)
340+ # No gradients at this leaf node
338341 return ResultOk2 (const_part, zero_mat, true )
339342 end
340343
341- index = if variable
342- tree. feature
343- else
344+ index = if (mode isa Bool && mode)
345+ tree. feature:: UInt16
346+ elseif (mode isa Bool && ! mode)
344347 (index_tree === nothing ? zero (UInt16) : index_tree. val:: UInt16 )
348+ elseif mode == :both
349+ index_tree:: NodeIndex
350+ if tree. constant
351+ index_tree. val:: UInt16 + UInt16 (size (cX, 1 ))
352+ else
353+ tree. feature:: UInt16
354+ end
345355 end
356+
346357 derivative_part = zero_mat
347- derivative_part[index, :] . = one (T)
358+ fill! ( @view ( derivative_part[index, :]), one (T) )
348359 return ResultOk2 (const_part, derivative_part, true )
349360end
350361
@@ -355,15 +366,15 @@ function grad_deg1_eval(
355366 cX:: AbstractMatrix{T} ,
356367 op:: F ,
357368 operators:: OperatorEnum ,
358- :: Val{variable } ,
359- ):: ResultOk2 where {T<: Number ,F,variable }
369+ :: Val{mode } ,
370+ ):: ResultOk2 where {T<: Number ,F,mode }
360371 result = eval_grad_tree_array (
361372 tree. l,
362373 n_gradients,
363374 index_tree === nothing ? index_tree : index_tree. l,
364375 cX,
365376 operators,
366- Val (variable ),
377+ Val (mode ),
367378 )
368379 ! result. ok && return result
369380
@@ -389,15 +400,15 @@ function grad_deg2_eval(
389400 cX:: AbstractMatrix{T} ,
390401 op:: F ,
391402 operators:: OperatorEnum ,
392- :: Val{variable } ,
393- ):: ResultOk2 where {T<: Number ,F,variable }
403+ :: Val{mode } ,
404+ ):: ResultOk2 where {T<: Number ,F,mode }
394405 result_l = eval_grad_tree_array (
395406 tree. l,
396407 n_gradients,
397408 index_tree === nothing ? index_tree : index_tree. l,
398409 cX,
399410 operators,
400- Val (variable ),
411+ Val (mode ),
401412 )
402413 ! result_l. ok && return result_l
403414 result_r = eval_grad_tree_array (
@@ -406,7 +417,7 @@ function grad_deg2_eval(
406417 index_tree === nothing ? index_tree : index_tree. r,
407418 cX,
408419 operators,
409- Val (variable ),
420+ Val (mode ),
410421 )
411422 ! result_r. ok && return result_r
412423