22module ExpressionModule
33
44using DispatchDoctor: @unstable
5- using ChainRulesCore: @ignore_derivatives
5+ using ChainRulesCore: CRC
66
77using .. NodeModule: AbstractExpressionNode, Node
88using .. OperatorEnumModule: AbstractOperatorEnum, OperatorEnum
@@ -19,7 +19,12 @@ import ..NodeUtilsModule:
1919 has_constants,
2020 get_constants,
2121 set_constants!
22+ import .. EvaluateModule: eval_tree_array, differentiable_eval_tree_array
23+ import .. EvaluateDerivativeModule: eval_grad_tree_array
24+ import .. EvaluationHelpersModule: _grad_evaluator
25+ import .. StringsModule: string_tree, print_tree
2226import .. ChainRulesModule: extract_gradient
27+ import .. SimplifyModule: combine_operators, simplify_tree!
2328
2429""" A wrapper for a named tuple to avoid piracy."""
2530struct Metadata{NT<: NamedTuple }
@@ -280,8 +285,6 @@ function extract_gradient(
280285 return extract_gradient (gradient. tree, get_tree (ex))
281286end
282287
283- import .. StringsModule: string_tree, print_tree
284-
285288function string_tree (
286289 ex:: AbstractExpression ,
287290 operators:: Union{AbstractOperatorEnum,Nothing} = nothing ;
@@ -311,8 +314,6 @@ function Base.show(io::IO, ::MIME"text/plain", ex::AbstractExpression)
311314 return print (io, string_tree (ex))
312315end
313316
314- import .. EvaluateModule: eval_tree_array, differentiable_eval_tree_array
315-
316317function max_feature (ex:: AbstractExpression )
317318 return tree_mapreduce (
318319 leaf -> leaf. constant ? zero (UInt16) : leaf. feature,
@@ -344,8 +345,6 @@ function eval_tree_array(
344345 return eval_tree_array (get_tree (ex), cX, get_operators (ex, operators); kws... )
345346end
346347
347- import .. EvaluateDerivativeModule: eval_grad_tree_array
348-
349348# skipped (not used much)
350349# - eval_diff_tree_array
351350# - differentiable_eval_tree_array
@@ -360,8 +359,6 @@ function eval_grad_tree_array(
360359 return eval_grad_tree_array (get_tree (ex), cX, get_operators (ex, operators); kws... )
361360end
362361
363- import .. EvaluationHelpersModule: _grad_evaluator
364-
365362function Base. adjoint (ex:: AbstractExpression )
366363 return ((args... ; kws... ) -> _grad_evaluator (ex, args... ; kws... ))
367364end
@@ -382,6 +379,4 @@ function (ex::AbstractExpression)(
382379 return get_tree (ex)(X, get_operators (ex, operators); kws... )
383380end
384381
385- import .. SimplifyModule: combine_operators, simplify_tree!
386-
387382end
0 commit comments