11module ParametricExpressionModule
22
33using DispatchDoctor: @stable , @unstable
4+ using ChainRulesCore: ChainRulesCore, NoTangent
45
5- using .. OperatorEnumModule: AbstractOperatorEnum
6+ using .. OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
67using .. NodeModule: AbstractExpressionNode, Node, tree_mapreduce
78using .. ExpressionModule: AbstractExpression, Metadata
9+ using .. ChainRulesModule: NodeTangent
810
911import .. NodeModule: constructorof, preserve_sharing, leaf_copy, leaf_hash, leaf_equal
1012import .. NodeUtilsModule:
@@ -250,7 +252,7 @@ function (ex::ParametricExpression)(
250252 operators:: Union{AbstractOperatorEnum,Nothing} = nothing ;
251253 kws... ,
252254) where {T}
253- (output, flag) = eval_tree_array (ex, X, classes, operators; kws... ) # Will error
255+ (output, flag) = eval_tree_array (ex, X, classes, operators; kws... )
254256 if ! flag
255257 output .= NaN
256258 end
@@ -276,6 +278,71 @@ function eval_tree_array(
276278 regular_tree = convert (Node, ex)
277279 return eval_tree_array (regular_tree, params_and_X, get_operators (ex, operators); kws... )
278280end
"""
    ChainRulesCore.rrule(
        ::typeof(eval_tree_array), ex::ParametricExpression, X, classes,
        operators=nothing; kws...
    )

Reverse-mode rule for evaluating a `ParametricExpression` with `eval_tree_array`.

The primal is the `(output, complete)` tuple returned by `eval_tree_array`. The
pullback takes the cotangent of that tuple (only the `output` part, `dY`, is used)
and returns cotangents for `(eval_tree_array, ex, X, classes, operators)`:

- `ex`: a named tuple with `tree` (a `NodeTangent` built on `ex.tree`) and
  `metadata.parameters` (a `num_params × num_classes` matrix); all other metadata
  fields are `NoTangent()`.
- `X`: a matrix with the same layout as `X`, each column scaled by the matching
  entry of `dY`.
- `eval_tree_array`, `classes` and `operators`: `NoTangent()`.

If a gradient evaluation reports that it did not complete, the corresponding
gradient is filled with `NaN`.
"""
function ChainRulesCore.rrule(
    ::typeof(eval_tree_array),
    ex::ParametricExpression{T},
    X::AbstractMatrix{T},
    classes::AbstractVector{<:Integer},
    operators::Union{AbstractOperatorEnum,Nothing}=nothing;
    kws...,
) where {T}
    primal, complete = eval_tree_array(ex, X, classes, operators; kws...)

    # TODO: Preferable to use the primal in the pullback somehow
    function pullback((dY, _))
        parameters = ex.metadata.parameters
        num_params = size(parameters, 1)
        num_classes = size(parameters, 2)

        # Expand the per-class parameters into one column per data row, then stack
        # them on top of `X`; this mirrors the input the forward evaluation feeds to
        # the converted tree.
        indexed_parameters = [
            parameters[i_parameter, classes[i_row]] for
            i_parameter in axes(parameters, 1), i_row in eachindex(classes)
        ]
        params_and_X = vcat(indexed_parameters, X)
        tree = ex.tree
        regular_tree = convert(Node, ex)

        # `operators` defaults to `nothing`; resolve it the same way the forward
        # evaluation does so the gradient calls always receive a usable operator set.
        resolved_operators = get_operators(ex, operators)

        # `variable=Val(false)` feeds the tree tangent below (presumably the gradient
        # with respect to the tree's constants); `variable=Val(true)` yields one
        # gradient row per row of `params_and_X`.
        _, gradient_tree, complete1 = eval_grad_tree_array(
            regular_tree, params_and_X, resolved_operators; variable=Val(false)
        )
        _, gradient_params_and_X, complete2 = eval_grad_tree_array(
            regular_tree, params_and_X, resolved_operators; variable=Val(true)
        )

        if !complete1
            gradient_tree .= NaN
        end
        if !complete2
            gradient_params_and_X .= NaN
        end

        # Contract the per-row gradients with the output cotangent.
        d_tree = NodeTangent(
            tree,
            sum(
                j -> @view(gradient_tree[:, j]) * dY[j],
                eachindex(dY, axes(gradient_tree, 2)),
            ),
        )

        # Rows 1:num_params belong to the expanded parameters, the rest to `X`.
        reshaped_d_Y = reshape(dY, 1, length(dY))
        d_indexed_parameters =
            @view(gradient_params_and_X[1:num_params, :]) .* reshaped_d_Y
        d_X = @view(gradient_params_and_X[(num_params + 1):end, :]) .* reshaped_d_Y

        # Scatter-add the per-row parameter cotangents back onto their class column.
        # `d_indexed_parameters` is already scaled by `dY`, so it must not be
        # multiplied by `dY[j]` a second time here.
        d_parameters = [
            sum(
                j -> d_indexed_parameters[param, j] * (classes[j] == class),
                eachindex(classes, axes(d_indexed_parameters, 2)),
            ) for param in 1:num_params, class in 1:num_classes
        ]
        d_ex = (;
            tree=d_tree,
            metadata=(;
                operators=NoTangent(),
                variable_names=NoTangent(),
                parameters=d_parameters,
                parameter_names=NoTangent(),
            ),
        )
        # `d_X` is a freshly materialized broadcast result, so no defensive copy is
        # needed before handing it back.
        return (NoTangent(), d_ex, d_X, NoTangent(), NoTangent())
    end

    return (primal, complete), pullback
end
345+
279346function string_tree (
280347 ex:: ParametricExpression ,
281348 operators:: Union{AbstractOperatorEnum,Nothing} = nothing ;
0 commit comments