@@ -79,15 +79,19 @@ function _grad_evaluator(tree::Node, X; kws...)
7979 end
8080end
8181
82+ function set_default_variable_names! (variable_names:: Vector{String} )
83+ return LATEST_VARIABLE_NAMES. x = variable_names
84+ end
85+
8286function create_evaluation_helpers! (operators:: OperatorEnum )
8387 LATEST_OPERATORS. x = operators
8488 return LATEST_OPERATORS_TYPE. x = IsOperatorEnum
8589end
86-
8790function create_evaluation_helpers! (operators:: GenericOperatorEnum )
8891 LATEST_OPERATORS. x = operators
8992 return LATEST_OPERATORS_TYPE. x = IsGenericOperatorEnum
9093end
94+
9195function lookup_op (@nospecialize (f), :: Val{degree} ) where {degree}
9296 mapping = degree == 1 ? LATEST_UNARY_OPERATOR_MAPPING : LATEST_BINARY_OPERATOR_MAPPING
9397 if ! haskey (mapping, f)
@@ -173,15 +177,26 @@ function _extend_binary_operator(f::Symbol, type_requirements, build_converters)
173177 end
174178end
175179
176- function _extend_operators (operators, skip_user_operators, __module__:: Module )
180+ function _extend_operators (operators, skip_user_operators, kws, __module__:: Module )
181+ empty_old_operators =
182+ if length (kws) == 1 && :empty_old_operators in map (x -> x. args[1 ], kws)
183+ @assert kws[1 ]. head == :(= )
184+ kws[1 ]. args[2 ]
185+ elseif length (kws) > 0
186+ error (
187+ " You passed the keywords $(kws) , but only `empty_old_operators` is supported." ,
188+ )
189+ else
190+ true
191+ end
177192 binary_ex = _extend_binary_operator (:f , :type_requirements , :build_converters )
178193 unary_ex = _extend_unary_operator (:f , :type_requirements )
179194 return quote
180195 local type_requirements
181196 local build_converters
182197 local binary_exists
183198 local unary_exists
184- if isa ($ operators, OperatorEnum)
199+ if isa ($ operators, $ OperatorEnum)
185200 type_requirements = Number
186201 build_converters = true
187202 binary_exists = $ (ALREADY_DEFINED_BINARY_OPERATORS). operator_enum
@@ -192,9 +207,11 @@ function _extend_operators(operators, skip_user_operators, __module__::Module)
192207 binary_exists = $ (ALREADY_DEFINED_BINARY_OPERATORS). generic_operator_enum
193208 unary_exists = $ (ALREADY_DEFINED_UNARY_OPERATORS). generic_operator_enum
194209 end
195- # Trigger errors if operators are not yet defined:
196- empty! ($ (LATEST_BINARY_OPERATOR_MAPPING))
197- empty! ($ (LATEST_UNARY_OPERATOR_MAPPING))
210+ if $ (empty_old_operators)
211+ # Trigger errors if operators are not yet defined:
212+ empty! ($ (LATEST_BINARY_OPERATOR_MAPPING))
213+ empty! ($ (LATEST_UNARY_OPERATOR_MAPPING))
214+ end
198215 for (op, func) in enumerate ($ (operators). binops)
199216 local f = Symbol (func)
200217 local skip = false
@@ -235,7 +252,7 @@ function _extend_operators(operators, skip_user_operators, __module__::Module)
235252end
236253
237254"""
238- @extend_operators operators
255+ @extend_operators operators [kws...]
239256
240257Extends all operators defined in this operator enum to work on the
241258`Node` type. While by default this is already done for operators defined
@@ -244,8 +261,8 @@ this does not apply to the user-defined operators. Thus, to do so, you must
244261apply this macro to the operator enum in the same module you have the operators
245262defined.
246263"""
247- macro extend_operators (operators)
248- ex = _extend_operators (operators, false , __module__)
264+ macro extend_operators (operators, kws ... )
265+ ex = _extend_operators (operators, false , kws, __module__)
249266 expected_type = AbstractOperatorEnum
250267 return esc (
251268 quote
@@ -258,13 +275,13 @@ macro extend_operators(operators)
258275end
259276
260277"""
261- @extend_operators_base operators
278+ @extend_operators_base operators [kws...]
262279
263280Similar to `@extend_operators`, but only extends operators already
264281defined in `Base`.
265282"""
266- macro extend_operators_base (operators)
267- ex = _extend_operators (operators, true , __module__)
283+ macro extend_operators_base (operators, kws ... )
284+ ex = _extend_operators (operators, true , kws, __module__)
268285 expected_type = AbstractOperatorEnum
269286 return esc (
270287 quote
@@ -277,7 +294,9 @@ macro extend_operators_base(operators)
277294end
278295
279296"""
280- OperatorEnum(; binary_operators=[], unary_operators=[], enable_autodiff::Bool=false, define_helper_functions::Bool=true)
297+ OperatorEnum(; binary_operators=[], unary_operators=[],
298+ enable_autodiff::Bool=false, define_helper_functions::Bool=true,
299+ empty_old_operators::Bool=true)
281300
282301Construct an `OperatorEnum` object, defining the possible expressions. This will also
283302redefine operators for `Node` types, as well as `show`, `print`, and `(::Node)(X)`.
@@ -292,12 +311,14 @@ It will automatically compute derivatives with `Zygote.jl`.
292311- `define_helper_functions::Bool=true`: Whether to define helper functions for creating
293312 and evaluating node types. Turn this off when doing precompilation. Note that these
294313 are *not* needed for the package to work; they are purely for convenience.
314+ - `empty_old_operators::Bool=true`: Whether to clear the old operators.
295315"""
296316function OperatorEnum (;
297317 binary_operators= [],
298318 unary_operators= [],
299319 enable_autodiff:: Bool = false ,
300320 define_helper_functions:: Bool = true ,
321+ empty_old_operators:: Bool = true ,
301322)
302323 @assert length (binary_operators) > 0 || length (unary_operators) > 0
303324
@@ -321,15 +342,16 @@ function OperatorEnum(;
321342 )
322343
323344 if define_helper_functions
324- @extend_operators_base operators
345+ @extend_operators_base operators empty_old_operators = empty_old_operators
325346 create_evaluation_helpers! (operators)
326347 end
327348
328349 return operators
329350end
330351
331352"""
332- GenericOperatorEnum(; binary_operators=[], unary_operators=[], define_helper_functions::Bool=true)
353+ GenericOperatorEnum(; binary_operators=[], unary_operators=[],
354+ define_helper_functions::Bool=true, empty_old_operators::Bool=true)
333355
334356Construct a `GenericOperatorEnum` object, defining possible expressions.
335357Unlike `OperatorEnum`, this enum one will work arbitrary operators and data types.
@@ -344,9 +366,13 @@ and `(::Node)(X)`.
344366- `define_helper_functions::Bool=true`: Whether to define helper functions for creating
345367 and evaluating node types. Turn this off when doing precompilation. Note that these
346368 are *not* needed for the package to work; they are purely for convenience.
369+ - `empty_old_operators::Bool=true`: Whether to clear the old operators.
347370"""
348371function GenericOperatorEnum (;
349- binary_operators= [], unary_operators= [], define_helper_functions:: Bool = true
372+ binary_operators= [],
373+ unary_operators= [],
374+ define_helper_functions:: Bool = true ,
375+ empty_old_operators:: Bool = true ,
350376)
351377 @assert length (binary_operators) > 0 || length (unary_operators) > 0
352378
@@ -356,7 +382,7 @@ function GenericOperatorEnum(;
356382 operators = GenericOperatorEnum (binary_operators, unary_operators)
357383
358384 if define_helper_functions
359- @extend_operators_base operators
385+ @extend_operators_base operators empty_old_operators = empty_old_operators
360386 create_evaluation_helpers! (operators)
361387 end
362388
0 commit comments