Skip to content

Commit 2c8ce96

Browse files
committed
Allow disabling behavior of emptying old operators
1 parent 89bd5d0 commit 2c8ce96

File tree

2 files changed

+44
-15
lines changed

2 files changed

+44
-15
lines changed

src/OperatorEnumConstruction.jl

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,18 @@ function _extend_binary_operator(f::Symbol, type_requirements, build_converters)
177177
end
178178
end
179179

180-
function _extend_operators(operators, skip_user_operators, __module__::Module)
180+
function _extend_operators(operators, skip_user_operators, kws, __module__::Module)
181+
empty_old_operators =
182+
if length(kws) == 1 && :empty_old_operators in map(x -> x.args[1], kws)
183+
@assert kws[1].head == :(=)
184+
kws[1].args[2]
185+
elseif length(kws) > 0
186+
error(
187+
"You passed the keywords $(kws), but only `empty_old_operators` is supported.",
188+
)
189+
else
190+
true
191+
end
181192
binary_ex = _extend_binary_operator(:f, :type_requirements, :build_converters)
182193
unary_ex = _extend_unary_operator(:f, :type_requirements)
183194
return quote
@@ -196,9 +207,11 @@ function _extend_operators(operators, skip_user_operators, __module__::Module)
196207
binary_exists = $(ALREADY_DEFINED_BINARY_OPERATORS).generic_operator_enum
197208
unary_exists = $(ALREADY_DEFINED_UNARY_OPERATORS).generic_operator_enum
198209
end
199-
# Trigger errors if operators are not yet defined:
200-
empty!($(LATEST_BINARY_OPERATOR_MAPPING))
201-
empty!($(LATEST_UNARY_OPERATOR_MAPPING))
210+
if $(empty_old_operators)
211+
# Trigger errors if operators are not yet defined:
212+
empty!($(LATEST_BINARY_OPERATOR_MAPPING))
213+
empty!($(LATEST_UNARY_OPERATOR_MAPPING))
214+
end
202215
for (op, func) in enumerate($(operators).binops)
203216
local f = Symbol(func)
204217
local skip = false
@@ -239,7 +252,7 @@ function _extend_operators(operators, skip_user_operators, __module__::Module)
239252
end
240253

241254
"""
242-
@extend_operators operators
255+
@extend_operators operators [kws...]
243256
244257
Extends all operators defined in this operator enum to work on the
245258
`Node` type. While by default this is already done for operators defined
@@ -248,8 +261,8 @@ this does not apply to the user-defined operators. Thus, to do so, you must
248261
apply this macro to the operator enum in the same module you have the operators
249262
defined.
250263
"""
251-
macro extend_operators(operators)
252-
ex = _extend_operators(operators, false, __module__)
264+
macro extend_operators(operators, kws...)
265+
ex = _extend_operators(operators, false, kws, __module__)
253266
expected_type = AbstractOperatorEnum
254267
return esc(
255268
quote
@@ -262,13 +275,13 @@ macro extend_operators(operators)
262275
end
263276

264277
"""
265-
@extend_operators_base operators
278+
@extend_operators_base operators [kws...]
266279
267280
Similar to `@extend_operators`, but only extends operators already
268281
defined in `Base`.
269282
"""
270-
macro extend_operators_base(operators)
271-
ex = _extend_operators(operators, true, __module__)
283+
macro extend_operators_base(operators, kws...)
284+
ex = _extend_operators(operators, true, kws, __module__)
272285
expected_type = AbstractOperatorEnum
273286
return esc(
274287
quote
@@ -281,7 +294,9 @@ macro extend_operators_base(operators)
281294
end
282295

283296
"""
284-
OperatorEnum(; binary_operators=[], unary_operators=[], enable_autodiff::Bool=false, define_helper_functions::Bool=true)
297+
OperatorEnum(; binary_operators=[], unary_operators=[],
298+
enable_autodiff::Bool=false, define_helper_functions::Bool=true,
299+
empty_old_operators::Bool=true)
285300
286301
Construct an `OperatorEnum` object, defining the possible expressions. This will also
287302
redefine operators for `Node` types, as well as `show`, `print`, and `(::Node)(X)`.
@@ -296,12 +311,14 @@ It will automatically compute derivatives with `Zygote.jl`.
296311
- `define_helper_functions::Bool=true`: Whether to define helper functions for creating
297312
and evaluating node types. Turn this off when doing precompilation. Note that these
298313
are *not* needed for the package to work; they are purely for convenience.
314+
- `empty_old_operators::Bool=true`: Whether to clear the old operators.
299315
"""
300316
function OperatorEnum(;
301317
binary_operators=[],
302318
unary_operators=[],
303319
enable_autodiff::Bool=false,
304320
define_helper_functions::Bool=true,
321+
empty_old_operators::Bool=true,
305322
)
306323
@assert length(binary_operators) > 0 || length(unary_operators) > 0
307324

@@ -325,15 +342,16 @@ function OperatorEnum(;
325342
)
326343

327344
if define_helper_functions
328-
@extend_operators_base operators
345+
@extend_operators_base operators empty_old_operators = empty_old_operators
329346
create_evaluation_helpers!(operators)
330347
end
331348

332349
return operators
333350
end
334351

335352
"""
336-
GenericOperatorEnum(; binary_operators=[], unary_operators=[], define_helper_functions::Bool=true)
353+
GenericOperatorEnum(; binary_operators=[], unary_operators=[],
354+
define_helper_functions::Bool=true, empty_old_operators::Bool=true)
337355
338356
Construct a `GenericOperatorEnum` object, defining possible expressions.
339357
Unlike `OperatorEnum`, this enum one will work arbitrary operators and data types.
@@ -348,9 +366,13 @@ and `(::Node)(X)`.
348366
- `define_helper_functions::Bool=true`: Whether to define helper functions for creating
349367
and evaluating node types. Turn this off when doing precompilation. Note that these
350368
are *not* needed for the package to work; they are purely for convenience.
369+
- `empty_old_operators::Bool=true`: Whether to clear the old operators.
351370
"""
352371
function GenericOperatorEnum(;
353-
binary_operators=[], unary_operators=[], define_helper_functions::Bool=true
372+
binary_operators=[],
373+
unary_operators=[],
374+
define_helper_functions::Bool=true,
375+
empty_old_operators::Bool=true,
354376
)
355377
@assert length(binary_operators) > 0 || length(unary_operators) > 0
356378

@@ -360,7 +382,7 @@ function GenericOperatorEnum(;
360382
operators = GenericOperatorEnum(binary_operators, unary_operators)
361383

362384
if define_helper_functions
363-
@extend_operators_base operators
385+
@extend_operators_base operators empty_old_operators = empty_old_operators
364386
create_evaluation_helpers!(operators)
365387
end
366388

test/test_safe_helpers.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,10 @@ operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[cos,
2929

3030
# Breaks:
3131
@test_throws ErrorException _square(x1 + x2 / x3) * x2 + 0.5
32+
33+
# We can also turn this behavior off:
34+
operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[cos, sin])
35+
operators = OperatorEnum(;
36+
binary_operators=[+, -, *, /], unary_operators=[cos, tan], empty_old_operators=false
37+
)
38+
@test tan(x1) == sin(x1)

0 commit comments

Comments
 (0)