Skip to content

Commit 28dcd2d

Browse files
authored
Merge pull request #47 from SymbolicML/construction-error
Allow operator aliases
2 parents cad5b8b + 580ed0f commit 28dcd2d

File tree

4 files changed

+57
-19
lines changed

4 files changed

+57
-19
lines changed

src/DynamicExpressions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import Reexport: @reexport
2727
set_constants!
2828
@reexport import .OperatorEnumModule: AbstractOperatorEnum
2929
@reexport import .OperatorEnumConstructionModule:
30-
OperatorEnum, GenericOperatorEnum, @extend_operators
30+
OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!
3131
@reexport import .EvaluateEquationModule: eval_tree_array, differentiable_eval_tree_array
3232
@reexport import .EvaluateEquationDerivativeModule:
3333
eval_diff_tree_array, eval_grad_tree_array

src/OperatorEnumConstruction.jl

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,19 @@ function _grad_evaluator(tree::Node, X; kws...)
7979
end
8080
end
8181

82+
function set_default_variable_names!(variable_names::Vector{String})
83+
return LATEST_VARIABLE_NAMES.x = variable_names
84+
end
85+
8286
function create_evaluation_helpers!(operators::OperatorEnum)
8387
LATEST_OPERATORS.x = operators
8488
return LATEST_OPERATORS_TYPE.x = IsOperatorEnum
8589
end
86-
8790
function create_evaluation_helpers!(operators::GenericOperatorEnum)
8891
LATEST_OPERATORS.x = operators
8992
return LATEST_OPERATORS_TYPE.x = IsGenericOperatorEnum
9093
end
94+
9195
function lookup_op(@nospecialize(f), ::Val{degree}) where {degree}
9296
mapping = degree == 1 ? LATEST_UNARY_OPERATOR_MAPPING : LATEST_BINARY_OPERATOR_MAPPING
9397
if !haskey(mapping, f)
@@ -173,15 +177,26 @@ function _extend_binary_operator(f::Symbol, type_requirements, build_converters)
173177
end
174178
end
175179

176-
function _extend_operators(operators, skip_user_operators, __module__::Module)
180+
function _extend_operators(operators, skip_user_operators, kws, __module__::Module)
181+
empty_old_operators =
182+
if length(kws) == 1 && :empty_old_operators in map(x -> x.args[1], kws)
183+
@assert kws[1].head == :(=)
184+
kws[1].args[2]
185+
elseif length(kws) > 0
186+
error(
187+
"You passed the keywords $(kws), but only `empty_old_operators` is supported.",
188+
)
189+
else
190+
true
191+
end
177192
binary_ex = _extend_binary_operator(:f, :type_requirements, :build_converters)
178193
unary_ex = _extend_unary_operator(:f, :type_requirements)
179194
return quote
180195
local type_requirements
181196
local build_converters
182197
local binary_exists
183198
local unary_exists
184-
if isa($operators, OperatorEnum)
199+
if isa($operators, $OperatorEnum)
185200
type_requirements = Number
186201
build_converters = true
187202
binary_exists = $(ALREADY_DEFINED_BINARY_OPERATORS).operator_enum
@@ -192,9 +207,11 @@ function _extend_operators(operators, skip_user_operators, __module__::Module)
192207
binary_exists = $(ALREADY_DEFINED_BINARY_OPERATORS).generic_operator_enum
193208
unary_exists = $(ALREADY_DEFINED_UNARY_OPERATORS).generic_operator_enum
194209
end
195-
# Trigger errors if operators are not yet defined:
196-
empty!($(LATEST_BINARY_OPERATOR_MAPPING))
197-
empty!($(LATEST_UNARY_OPERATOR_MAPPING))
210+
if $(empty_old_operators)
211+
# Trigger errors if operators are not yet defined:
212+
empty!($(LATEST_BINARY_OPERATOR_MAPPING))
213+
empty!($(LATEST_UNARY_OPERATOR_MAPPING))
214+
end
198215
for (op, func) in enumerate($(operators).binops)
199216
local f = Symbol(func)
200217
local skip = false
@@ -235,7 +252,7 @@ function _extend_operators(operators, skip_user_operators, __module__::Module)
235252
end
236253

237254
"""
238-
@extend_operators operators
255+
@extend_operators operators [kws...]
239256
240257
Extends all operators defined in this operator enum to work on the
241258
`Node` type. While by default this is already done for operators defined
@@ -244,8 +261,8 @@ this does not apply to the user-defined operators. Thus, to do so, you must
244261
apply this macro to the operator enum in the same module you have the operators
245262
defined.
246263
"""
247-
macro extend_operators(operators)
248-
ex = _extend_operators(operators, false, __module__)
264+
macro extend_operators(operators, kws...)
265+
ex = _extend_operators(operators, false, kws, __module__)
249266
expected_type = AbstractOperatorEnum
250267
return esc(
251268
quote
@@ -258,13 +275,13 @@ macro extend_operators(operators)
258275
end
259276

260277
"""
261-
@extend_operators_base operators
278+
@extend_operators_base operators [kws...]
262279
263280
Similar to `@extend_operators`, but only extends operators already
264281
defined in `Base`.
265282
"""
266-
macro extend_operators_base(operators)
267-
ex = _extend_operators(operators, true, __module__)
283+
macro extend_operators_base(operators, kws...)
284+
ex = _extend_operators(operators, true, kws, __module__)
268285
expected_type = AbstractOperatorEnum
269286
return esc(
270287
quote
@@ -277,7 +294,9 @@ macro extend_operators_base(operators)
277294
end
278295

279296
"""
280-
OperatorEnum(; binary_operators=[], unary_operators=[], enable_autodiff::Bool=false, define_helper_functions::Bool=true)
297+
OperatorEnum(; binary_operators=[], unary_operators=[],
298+
enable_autodiff::Bool=false, define_helper_functions::Bool=true,
299+
empty_old_operators::Bool=true)
281300
282301
Construct an `OperatorEnum` object, defining the possible expressions. This will also
283302
redefine operators for `Node` types, as well as `show`, `print`, and `(::Node)(X)`.
@@ -292,12 +311,14 @@ It will automatically compute derivatives with `Zygote.jl`.
292311
- `define_helper_functions::Bool=true`: Whether to define helper functions for creating
293312
and evaluating node types. Turn this off when doing precompilation. Note that these
294313
are *not* needed for the package to work; they are purely for convenience.
314+
- `empty_old_operators::Bool=true`: Whether to clear the old operators.
295315
"""
296316
function OperatorEnum(;
297317
binary_operators=[],
298318
unary_operators=[],
299319
enable_autodiff::Bool=false,
300320
define_helper_functions::Bool=true,
321+
empty_old_operators::Bool=true,
301322
)
302323
@assert length(binary_operators) > 0 || length(unary_operators) > 0
303324

@@ -321,15 +342,16 @@ function OperatorEnum(;
321342
)
322343

323344
if define_helper_functions
324-
@extend_operators_base operators
345+
@extend_operators_base operators empty_old_operators = empty_old_operators
325346
create_evaluation_helpers!(operators)
326347
end
327348

328349
return operators
329350
end
330351

331352
"""
332-
GenericOperatorEnum(; binary_operators=[], unary_operators=[], define_helper_functions::Bool=true)
353+
GenericOperatorEnum(; binary_operators=[], unary_operators=[],
354+
define_helper_functions::Bool=true, empty_old_operators::Bool=true)
333355
334356
Construct a `GenericOperatorEnum` object, defining possible expressions.
335357
Unlike `OperatorEnum`, this enum one will work arbitrary operators and data types.
@@ -344,9 +366,13 @@ and `(::Node)(X)`.
344366
- `define_helper_functions::Bool=true`: Whether to define helper functions for creating
345367
and evaluating node types. Turn this off when doing precompilation. Note that these
346368
are *not* needed for the package to work; they are purely for convenience.
369+
- `empty_old_operators::Bool=true`: Whether to clear the old operators.
347370
"""
348371
function GenericOperatorEnum(;
349-
binary_operators=[], unary_operators=[], define_helper_functions::Bool=true
372+
binary_operators=[],
373+
unary_operators=[],
374+
define_helper_functions::Bool=true,
375+
empty_old_operators::Bool=true,
350376
)
351377
@assert length(binary_operators) > 0 || length(unary_operators) > 0
352378

@@ -356,7 +382,7 @@ function GenericOperatorEnum(;
356382
operators = GenericOperatorEnum(binary_operators, unary_operators)
357383

358384
if define_helper_functions
359-
@extend_operators_base operators
385+
@extend_operators_base operators empty_old_operators = empty_old_operators
360386
create_evaluation_helpers!(operators)
361387
end
362388

test/test_print.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ end
111111
empty!(DynamicExpressions.OperatorEnumConstructionModule.LATEST_VARIABLE_NAMES.x)
112112
@test string(tree) == "((x1 * x2) + x3)"
113113
# Check if we can pass the wrong number of variable names:
114-
DynamicExpressions.OperatorEnumConstructionModule.LATEST_VARIABLE_NAMES.x = ["k1"]
114+
set_default_variable_names!(["k1"])
115115
@test string(tree) == "((k1 * x2) + x3)"
116116
empty!(DynamicExpressions.OperatorEnumConstructionModule.LATEST_VARIABLE_NAMES.x)
117117
end

test/test_safe_helpers.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,15 @@ operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[cos,
2929

3030
# Breaks:
3131
@test_throws ErrorException _square(x1 + x2 / x3) * x2 + 0.5
32+
33+
# We can also turn this behavior off:
34+
operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[cos, sin])
35+
operators = OperatorEnum(;
36+
binary_operators=[+, -, *, /], unary_operators=[cos, tan], empty_old_operators=false
37+
)
38+
@test tan(x1) == sin(x1)
39+
40+
# Should catch errors in kws:
41+
@test_throws LoadError begin
42+
@eval @extend_operators operators empty_old_operators_bad_kw = true
43+
end

0 commit comments

Comments
 (0)