Skip to content

Commit c051eee

Browse files
authored
Merge pull request #8 from SymbolicML/safer-operator-extending
Safer way of extending user-defined operators
2 parents 51eb902 + 34d1018 commit c051eee

File tree

11 files changed

+309
-129
lines changed

11 files changed

+309
-129
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
4-
version = "0.3.2"
4+
version = "0.4.0"
55

66
[deps]
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

README.md

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -159,19 +159,28 @@ x1 = Node(String; feature=1)
159159
This node, will be used to index input data (whatever it may be) with either `data[feature]` (1D abstract arrays) or `selectdim(data, 1, feature)` (ND abstract arrays). Let's now define some operators to use:
160160

161161
```julia
162-
my_string_func(x::String) = "Hello $x"
162+
my_string_func(x::String) = "ello $x"
163163

164164
operators = GenericOperatorEnum(;
165165
binary_operators=[*],
166-
unary_operators=[my_string_func],
167-
extend_user_operators=true)
166+
unary_operators=[my_string_func]
167+
)
168+
```
169+
170+
Now, let's extend our operators to work with the
171+
expression types used by `DynamicExpressions.jl`:
172+
173+
```julia
174+
@extend_operators operators
168175
```
169176

170177
Now, let's create an expression:
171178

172179
```julia
173-
tree = x1 * " World!"
174-
tree(["Hello", "Me?"])
180+
tree = "H" * my_string_func(x1)
181+
# ^ `(H * my_string_func(x1))`
182+
183+
tree(["World!", "Me?"])
175184
# Hello World!
176185
```
177186

@@ -202,7 +211,8 @@ vec_add(x, y) = x .+ y
202211
vec_square(x) = x .* x
203212

204213
# Set up an operator enum:
205-
operators = GenericOperatorEnum(;binary_operators=[vec_add], unary_operators=[vec_square], extend_user_operators=true)
214+
operators = GenericOperatorEnum(;binary_operators=[vec_add], unary_operators=[vec_square])
215+
@extend_operators operators
206216

207217
# Construct the expression:
208218
tree = vec_add(vec_add(vec_square(x1), c2), c1)

docs/src/types.md

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,35 @@ OperatorEnum
1313
Construct this operator specification as follows:
1414

1515
```@docs
16-
OperatorEnum(; binary_operators, unary_operators, enable_autodiff)
16+
OperatorEnum(; binary_operators=[], unary_operators=[], enable_autodiff::Bool=false, define_helper_functions::Bool=true)
1717
```
1818

1919
This is just for scalar real operators. However, you can use
2020
the following for more general operators:
2121

2222
```@docs
23-
GenericOperatorEnum(; binary_operators=[], unary_operators=[], extend_user_operators::Bool=false)
23+
GenericOperatorEnum(; binary_operators=[], unary_operators=[], define_helper_functions::Bool=true)
2424
```
2525

26+
By default, these operators will define helper functions for constructing trees,
27+
so that you can write `Node(;feature=1) + Node(;feature=2)` instead of
28+
`Node(1, Node(;feature=1), Node(;feature=2))` (assuming `+` is the first operator).
29+
You can turn this off with `define_helper_functions=false`.
30+
31+
For other operators *not* found in `Base`, including user-defined functions, you may
32+
use the `@extend_operators` macro:
33+
34+
```@docs
35+
@extend_operators operators
36+
```
37+
38+
This will extend the operators you have passed to work with `Node` types, so that
39+
it is easier to construct expression trees.
40+
41+
Note that you are free to use the `Node` constructors directly.
42+
This is a more robust approach, and should be used when creating libraries
43+
which use `DynamicExpressions.jl`.
44+
2645
## Equations
2746

2847
Equations are specified as binary trees with the `Node` type, defined

src/DynamicExpressions.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ using Reexport
2424
get_constants,
2525
set_constants
2626
@reexport import .OperatorEnumModule: AbstractOperatorEnum
27-
@reexport import .OperatorEnumConstructionModule: OperatorEnum, GenericOperatorEnum
27+
@reexport import .OperatorEnumConstructionModule:
28+
OperatorEnum, GenericOperatorEnum, @extend_operators
2829
@reexport import .EvaluateEquationModule: eval_tree_array, differentiable_eval_tree_array
2930
@reexport import .EvaluateEquationDerivativeModule:
3031
eval_diff_tree_array, eval_grad_tree_array

0 commit comments

Comments
 (0)