@@ -177,7 +177,18 @@ function _extend_binary_operator(f::Symbol, type_requirements, build_converters)
177177 end
178178end
179179
180- function _extend_operators (operators, skip_user_operators, __module__:: Module )
180+ function _extend_operators (operators, skip_user_operators, kws, __module__:: Module )
181+ empty_old_operators =
182+ if length (kws) == 1 && :empty_old_operators in map (x -> x. args[1 ], kws)
183+ @assert kws[1 ]. head == :(= )
184+ kws[1 ]. args[2 ]
185+ elseif length (kws) > 0
186+ error (
187+ " You passed the keywords $(kws) , but only `empty_old_operators` is supported." ,
188+ )
189+ else
190+ true
191+ end
181192 binary_ex = _extend_binary_operator (:f , :type_requirements , :build_converters )
182193 unary_ex = _extend_unary_operator (:f , :type_requirements )
183194 return quote
@@ -196,9 +207,11 @@ function _extend_operators(operators, skip_user_operators, __module__::Module)
196207 binary_exists = $ (ALREADY_DEFINED_BINARY_OPERATORS). generic_operator_enum
197208 unary_exists = $ (ALREADY_DEFINED_UNARY_OPERATORS). generic_operator_enum
198209 end
199- # Trigger errors if operators are not yet defined:
200- empty! ($ (LATEST_BINARY_OPERATOR_MAPPING))
201- empty! ($ (LATEST_UNARY_OPERATOR_MAPPING))
210+ if $ (empty_old_operators)
211+ # Trigger errors if operators are not yet defined:
212+ empty! ($ (LATEST_BINARY_OPERATOR_MAPPING))
213+ empty! ($ (LATEST_UNARY_OPERATOR_MAPPING))
214+ end
202215 for (op, func) in enumerate ($ (operators). binops)
203216 local f = Symbol (func)
204217 local skip = false
@@ -239,7 +252,7 @@ function _extend_operators(operators, skip_user_operators, __module__::Module)
239252end
240253
241254"""
242- @extend_operators operators
255+ @extend_operators operators [kws...]
243256
244257Extends all operators defined in this operator enum to work on the
245258`Node` type. While by default this is already done for operators defined
@@ -248,8 +261,8 @@ this does not apply to the user-defined operators. Thus, to do so, you must
248261apply this macro to the operator enum in the same module you have the operators
249262defined.
250263"""
251- macro extend_operators (operators)
252- ex = _extend_operators (operators, false , __module__)
264+ macro extend_operators (operators, kws ... )
265+ ex = _extend_operators (operators, false , kws, __module__)
253266 expected_type = AbstractOperatorEnum
254267 return esc (
255268 quote
@@ -262,13 +275,13 @@ macro extend_operators(operators)
262275end
263276
264277"""
265- @extend_operators_base operators
278+ @extend_operators_base operators [kws...]
266279
267280Similar to `@extend_operators`, but only extends operators already
268281defined in `Base`.
269282"""
270- macro extend_operators_base (operators)
271- ex = _extend_operators (operators, true , __module__)
283+ macro extend_operators_base (operators, kws ... )
284+ ex = _extend_operators (operators, true , kws, __module__)
272285 expected_type = AbstractOperatorEnum
273286 return esc (
274287 quote
@@ -281,7 +294,9 @@ macro extend_operators_base(operators)
281294end
282295
283296"""
284- OperatorEnum(; binary_operators=[], unary_operators=[], enable_autodiff::Bool=false, define_helper_functions::Bool=true)
297+ OperatorEnum(; binary_operators=[], unary_operators=[],
298+ enable_autodiff::Bool=false, define_helper_functions::Bool=true,
299+ empty_old_operators::Bool=true)
285300
286301Construct an `OperatorEnum` object, defining the possible expressions. This will also
287302redefine operators for `Node` types, as well as `show`, `print`, and `(::Node)(X)`.
@@ -296,12 +311,14 @@ It will automatically compute derivatives with `Zygote.jl`.
296311- `define_helper_functions::Bool=true`: Whether to define helper functions for creating
297312 and evaluating node types. Turn this off when doing precompilation. Note that these
298313 are *not* needed for the package to work; they are purely for convenience.
314+ - `empty_old_operators::Bool=true`: Whether to clear the old operators.
299315"""
300316function OperatorEnum (;
301317 binary_operators= [],
302318 unary_operators= [],
303319 enable_autodiff:: Bool = false ,
304320 define_helper_functions:: Bool = true ,
321+ empty_old_operators:: Bool = true ,
305322)
306323 @assert length (binary_operators) > 0 || length (unary_operators) > 0
307324
@@ -325,15 +342,16 @@ function OperatorEnum(;
325342 )
326343
327344 if define_helper_functions
328- @extend_operators_base operators
345+ @extend_operators_base operators empty_old_operators = empty_old_operators
329346 create_evaluation_helpers! (operators)
330347 end
331348
332349 return operators
333350end
334351
335352"""
336- GenericOperatorEnum(; binary_operators=[], unary_operators=[], define_helper_functions::Bool=true)
353+ GenericOperatorEnum(; binary_operators=[], unary_operators=[],
354+ define_helper_functions::Bool=true, empty_old_operators::Bool=true)
337355
338356Construct a `GenericOperatorEnum` object, defining possible expressions.
339357Unlike `OperatorEnum`, this enum one will work arbitrary operators and data types.
@@ -348,9 +366,13 @@ and `(::Node)(X)`.
348366- `define_helper_functions::Bool=true`: Whether to define helper functions for creating
349367 and evaluating node types. Turn this off when doing precompilation. Note that these
350368 are *not* needed for the package to work; they are purely for convenience.
369+ - `empty_old_operators::Bool=true`: Whether to clear the old operators.
351370"""
352371function GenericOperatorEnum (;
353- binary_operators= [], unary_operators= [], define_helper_functions:: Bool = true
372+ binary_operators= [],
373+ unary_operators= [],
374+ define_helper_functions:: Bool = true ,
375+ empty_old_operators:: Bool = true ,
354376)
355377 @assert length (binary_operators) > 0 || length (unary_operators) > 0
356378
@@ -360,7 +382,7 @@ function GenericOperatorEnum(;
360382 operators = GenericOperatorEnum (binary_operators, unary_operators)
361383
362384 if define_helper_functions
363- @extend_operators_base operators
385+ @extend_operators_base operators empty_old_operators = empty_old_operators
364386 create_evaluation_helpers! (operators)
365387 end
366388
0 commit comments