@@ -6,82 +6,135 @@ import ..EvaluateEquationModule: eval_tree_array
66import .. EvaluateEquationDerivativeModule: eval_grad_tree_array, _zygote_gradient
77import .. EvaluationHelpersModule: _grad_evaluator
88
9- function create_evaluation_helpers! (operators:: OperatorEnum )
10- @eval begin
11- Base. print (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
12- Base. show (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
13- function (tree:: Node )(X; kws... )
14- Base. depwarn (
15- " The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead." ,
16- :Node ,
17- )
18- return tree (X, $ operators; kws... )
19- end
20- # Gradients:
21- function _grad_evaluator (tree:: Node , X; kws... )
22- Base. depwarn (
23- " The `tree'(X; kws...)` syntax is deprecated. Use `tree'(X, operators; kws...)` instead." ,
24- :Node ,
25- )
26- return _grad_evaluator (tree, X, $ operators; kws... )
27- end
9+ """ Used to set a default value for `operators` for ease of use."""
10+ @enum AvailableOperatorTypes begin
11+ IsNothing
12+ IsOperatorEnum
13+ IsGenericOperatorEnum
14+ end
15+
16+ # These constants are purely for convenience. Internal code
17+ # should make use of `Node`, `string_tree`, `eval_tree_array`,
18+ # and `eval_grad_tree_array` directly.
19+
20+ const LATEST_OPERATORS = Ref {Union{Nothing,AbstractOperatorEnum}} (nothing )
21+ const LATEST_OPERATORS_TYPE = Ref {AvailableOperatorTypes} (IsNothing)
22+ const LATEST_UNARY_OPERATOR_MAPPING = Dict {Function,Int} ()
23+ const LATEST_BINARY_OPERATOR_MAPPING = Dict {Function,Int} ()
24+ const ALREADY_DEFINED_UNARY_OPERATORS = (;
25+ operator_enum= Dict {Function,Bool} (), generic_operator_enum= Dict {Function,Bool} ()
26+ )
27+ const ALREADY_DEFINED_BINARY_OPERATORS = (;
28+ operator_enum= Dict {Function,Bool} (), generic_operator_enum= Dict {Function,Bool} ()
29+ )
30+
31+ function Base. show (io:: IO , tree:: Node )
32+ latest_operators_type = LATEST_OPERATORS_TYPE. x
33+ if latest_operators_type == IsNothing
34+ return print (io, string_tree (tree))
35+ elseif latest_operators_type == IsOperatorEnum
36+ latest_operators = LATEST_OPERATORS. x:: OperatorEnum
37+ return print (io, string_tree (tree, latest_operators))
38+ else
39+ latest_operators = LATEST_OPERATORS. x:: GenericOperatorEnum
40+ return print (io, string_tree (tree, latest_operators))
2841 end
2942end
43+ function (tree:: Node )(X; kws... )
44+ Base. depwarn (
45+ " The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead." ,
46+ :Node ,
47+ )
48+ latest_operators_type = LATEST_OPERATORS_TYPE. x
49+ if latest_operators_type == IsNothing
50+ error (" Please use the `tree(X, operators; kws...)` syntax instead." )
51+ elseif latest_operators_type == IsOperatorEnum
52+ latest_operators = LATEST_OPERATORS. x:: OperatorEnum
53+ return tree (X, latest_operators; kws... )
54+ else
55+ latest_operators = LATEST_OPERATORS. x:: GenericOperatorEnum
56+ return tree (X, latest_operators; kws... )
57+ end
58+ end
59+
60+ function _grad_evaluator (tree:: Node , X; kws... )
61+ Base. depwarn (
62+ " The `tree'(X; kws...)` syntax is deprecated. Use `tree'(X, operators; kws...)` instead." ,
63+ :Node ,
64+ )
65+ latest_operators_type = LATEST_OPERATORS_TYPE. x
66+ # return _grad_evaluator(tree, X, $operators; kws...)
67+ if latest_operators_type == IsNothing
68+ error (" Please use the `tree'(X, operators; kws...)` syntax instead." )
69+ elseif latest_operators_type == IsOperatorEnum
70+ latest_operators = LATEST_OPERATORS. x:: OperatorEnum
71+ return _grad_evaluator (tree, X, latest_operators; kws... )
72+ else
73+ error (" Gradients are not implemented for `GenericOperatorEnum`." )
74+ end
75+ end
76+
77+ function create_evaluation_helpers! (operators:: OperatorEnum )
78+ LATEST_OPERATORS. x = operators
79+ return LATEST_OPERATORS_TYPE. x = IsOperatorEnum
80+ end
3081
3182function create_evaluation_helpers! (operators:: GenericOperatorEnum )
32- @eval begin
33- Base. print (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
34- Base. show (io:: IO , tree:: Node ) = print (io, string_tree (tree, $ operators))
35-
36- function (tree:: Node )(X; kws... )
37- Base. depwarn (
38- " The `tree(X; kws...)` syntax is deprecated. Use `tree(X, operators; kws...)` instead." ,
39- :Node ,
40- )
41- return tree (X, $ operators; kws... )
42- end
43- function _grad_evaluator (tree:: Node , X; kws... )
44- return error (" Gradients are not implemented for `GenericOperatorEnum`." )
45- end
83+ LATEST_OPERATORS. x = operators
84+ return LATEST_OPERATORS_TYPE. x = IsGenericOperatorEnum
85+ end
86+ function lookup_op (@nospecialize (f), :: Val{degree} ) where {degree}
87+ mapping = degree == 1 ? LATEST_UNARY_OPERATOR_MAPPING : LATEST_BINARY_OPERATOR_MAPPING
88+ if ! haskey (mapping, f)
89+ error (
90+ " Convenience constructor for `Node` using operator `$(f) ` is out-of-date. " *
91+ " Please create an `OperatorEnum` (or `GenericOperatorEnum`) with " *
92+ " `define_helper_functions=true` and pass `$(f) `." ,
93+ )
4694 end
95+ return mapping[f]
4796end
4897
49- function _extend_unary_operator (f:: Symbol , op, type_requirements)
98+ function _extend_unary_operator (f:: Symbol , type_requirements)
5099 quote
51100 quote
52101 function $ ($ f)(l:: Node{T} ):: Node{T} where {T<: $ ($ type_requirements)}
53102 return if (l. degree == 0 && l. constant)
54103 Node (T; val= $ ($ f)(l. val:: T ))
55104 else
56- Node ($ ($ op), l)
105+ latest_op_idx = $ ($ lookup_op)($ ($ f), Val (1 ))
106+ Node (latest_op_idx, l)
57107 end
58108 end
59109 end
60110 end
61111end
62112
63- function _extend_binary_operator (f:: Symbol , op, type_requirements, build_converters)
113+ function _extend_binary_operator (f:: Symbol , type_requirements, build_converters)
64114 quote
65115 quote
66116 function $ ($ f)(l:: Node{T} , r:: Node{T} ) where {T<: $ ($ type_requirements)}
67117 if (l. degree == 0 && l. constant && r. degree == 0 && r. constant)
68118 Node (T; val= $ ($ f)(l. val:: T , r. val:: T ))
69119 else
70- Node ($ ($ op), l, r)
120+ latest_op_idx = $ ($ lookup_op)($ ($ f), Val (2 ))
121+ Node (latest_op_idx, l, r)
71122 end
72123 end
73124 function $ ($ f)(l:: Node{T} , r:: T ) where {T<: $ ($ type_requirements)}
74125 if l. degree == 0 && l. constant
75126 Node (T; val= $ ($ f)(l. val:: T , r))
76127 else
77- Node ($ ($ op), l, Node (T; val= r))
128+ latest_op_idx = $ ($ lookup_op)($ ($ f), Val (2 ))
129+ Node (latest_op_idx, l, Node (T; val= r))
78130 end
79131 end
80132 function $ ($ f)(l:: T , r:: Node{T} ) where {T<: $ ($ type_requirements)}
81133 if r. degree == 0 && r. constant
82134 Node (T; val= $ ($ f)(l, r. val:: T ))
83135 else
84- Node ($ ($ op), Node (T; val= l), r)
136+ latest_op_idx = $ ($ lookup_op)($ ($ f), Val (2 ))
137+ Node (latest_op_idx, Node (T; val= l), r)
85138 end
86139 end
87140 if $ ($ build_converters)
@@ -116,37 +169,62 @@ function _extend_binary_operator(f::Symbol, op, type_requirements, build_convert
116169end
117170
118171function _extend_operators (operators, skip_user_operators, __module__:: Module )
119- binary_ex = _extend_binary_operator (:f , :op , : type_requirements , :build_converters )
120- unary_ex = _extend_unary_operator (:f , :op , : type_requirements )
172+ binary_ex = _extend_binary_operator (:f , :type_requirements , :build_converters )
173+ unary_ex = _extend_unary_operator (:f , :type_requirements )
121174 return quote
122175 local type_requirements
123176 local build_converters
177+ local binary_exists
178+ local unary_exists
124179 if isa ($ operators, OperatorEnum)
125180 type_requirements = Number
126181 build_converters = true
182+ binary_exists = $ (ALREADY_DEFINED_BINARY_OPERATORS). operator_enum
183+ unary_exists = $ (ALREADY_DEFINED_UNARY_OPERATORS). operator_enum
127184 else
128185 type_requirements = Any
129186 build_converters = false
187+ binary_exists = $ (ALREADY_DEFINED_BINARY_OPERATORS). generic_operator_enum
188+ unary_exists = $ (ALREADY_DEFINED_UNARY_OPERATORS). generic_operator_enum
130189 end
131- for (op, f) in enumerate (map (Symbol, $ (operators). binops))
190+ # Trigger errors if operators are not yet defined:
191+ empty! ($ (LATEST_BINARY_OPERATOR_MAPPING))
192+ empty! ($ (LATEST_UNARY_OPERATOR_MAPPING))
193+ for (op, func) in enumerate ($ (operators). binops)
194+ local f = Symbol (func)
195+ local skip = false
132196 if isdefined (Base, f)
133197 f = :(Base.$ (f))
134198 elseif $ (skip_user_operators)
135- continue
199+ skip = true
136200 else
137201 f = :($ ($ __module__). $ (f))
138202 end
139- eval ($ binary_ex)
203+ $ (LATEST_BINARY_OPERATOR_MAPPING)[func] = op
204+ skip && continue
205+ # Avoid redefining methods:
206+ if ! haskey (unary_exists, func)
207+ eval ($ binary_ex)
208+ unary_exists[func] = true
209+ end
140210 end
141- for (op, f) in enumerate (map (Symbol, $ (operators). unaops))
211+ for (op, func) in enumerate ($ (operators). unaops)
212+ local f = Symbol (func)
213+ local skip = false
142214 if isdefined (Base, f)
143215 f = :(Base.$ (f))
144216 elseif $ (skip_user_operators)
145- continue
217+ skip = true
146218 else
147219 f = :($ ($ __module__). $ (f))
148220 end
149- eval ($ unary_ex)
221+ $ (LATEST_UNARY_OPERATOR_MAPPING)[func] = op
222+ skip && continue
223+ # Avoid redefining methods:
224+ if ! haskey (binary_exists, func)
225+ eval ($ unary_ex)
226+ binary_exists[func] = true
227+ end
150228 end
151229 end
152230end
@@ -162,14 +240,16 @@ apply this macro to the operator enum in the same module you have the operators
162240defined.
163241"""
164242macro extend_operators (operators)
165- ex = _extend_operators (esc ( operators) , false , __module__)
243+ ex = _extend_operators (operators, false , __module__)
166244 expected_type = AbstractOperatorEnum
167- quote
168- if ! isa ($ (esc (operators)), $ expected_type)
169- error (" You must pass an operator enum to `@extend_operators`." )
170- end
171- $ ex
172- end
245+ return esc (
246+ quote
247+ if ! isa ($ (operators), $ expected_type)
248+ error (" You must pass an operator enum to `@extend_operators`." )
249+ end
250+ $ ex
251+ end ,
252+ )
173253end
174254
175255"""
@@ -179,14 +259,16 @@ Similar to `@extend_operators`, but only extends operators already
179259defined in `Base`.
180260"""
181261macro extend_operators_base (operators)
182- ex = _extend_operators (esc ( operators) , true , __module__)
262+ ex = _extend_operators (operators, true , __module__)
183263 expected_type = AbstractOperatorEnum
184- quote
185- if ! isa ($ (esc (operators)), $ expected_type)
186- error (" You must pass an operator enum to `@extend_operators_base`." )
187- end
188- $ ex
189- end
264+ return esc (
265+ quote
266+ if ! isa ($ (operators), $ expected_type)
267+ error (" You must pass an operator enum to `@extend_operators_base`." )
268+ end
269+ $ ex
270+ end ,
271+ )
190272end
191273
192274"""
0 commit comments