Skip to content

Commit 1d3ed78

Browse files
committed
Test explicit calling syntax
1 parent ee25be7 commit 1d3ed78

File tree

6 files changed

+48
-18
lines changed

6 files changed

+48
-18
lines changed

test/test_custom_operators.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ tree = op1(op2(x1, x2), op3(x1))
2121
@test repr(tree) == "op1(op2(x1, x2), op3(x1))"
2222
# Test evaluation:
2323
X = randn(MersenneTwister(0), Float32, 2, 10);
24-
@test tree(X) ((x1, x2) -> op1(op2(x1, x2), op3(x1))).(X[1, :], X[2, :])
24+
@test tree(X, operators) ((x1, x2) -> op1(op2(x1, x2), op3(x1))).(X[1, :], X[2, :])
2525

2626
# Now, test that we can work with operators defined in modules
2727
module A
@@ -47,7 +47,7 @@ function create_and_eval_tree()
4747
tree = my_func_a(my_func_a(x2, 0.2), my_func_b(x1))
4848
func = (x1, x2) -> my_func_a(my_func_a(x2, 0.2), my_func_b(x1))
4949
X = randn(MersenneTwister(0), 2, 20)
50-
return tree(X), func.(X[1, :], X[2, :])
50+
return tree(X, operators), func.(X[1, :], X[2, :])
5151
end
5252

5353
end
@@ -74,4 +74,7 @@ c1 = Node(Float64; val=0.2)
7474
tree = my_func_c(my_func_c(x2, 0.2), my_func_d(x1))
7575
func = (x1, x2) -> my_func_c(my_func_c(x2, 0.2), my_func_d(x1))
7676
X = randn(MersenneTwister(0), 2, 20)
77+
@test tree(X, operators) func.(X[1, :], X[2, :])
78+
79+
# Test deprecated implicit syntax:
7780
@test tree(X) func.(X[1, :], X[2, :])

test/test_derivatives.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,15 @@ for type in [Float16, Float32, Float64], turbo in [true, false]
8888
i in 1:nfeatures
8989
],
9090
)'
91-
predicted_grad3 = tree'(X)
91+
predicted_grad3 = tree'(X, operators; turbo=turbo)
92+
# Test deprecated syntax:
93+
predicted_grad4 = tree'(X; turbo=turbo)
9294

9395
# Print largest difference between predicted_grad, true_grad:
9496
@test array_test(predicted_grad, true_grad)
9597
@test array_test(predicted_grad2, true_grad)
9698
@test array_test(predicted_grad3, true_grad)
99+
@test array_test(predicted_grad4, true_grad)
97100

98101
# Make sure that the array_test actually works:
99102
@test !array_test(predicted_grad .* 0, true_grad)

test/test_error_handling.jl

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,27 @@
11
using DynamicExpressions
22
using Test
33

4+
# Before defining OperatorEnum, calling the implicit (deprecated)
5+
# syntax should fail:
6+
tree = Node(; feature=1)
7+
try
8+
tree([1.0 2.0]')
9+
@test false
10+
catch e
11+
@test isa(e, ErrorException)
12+
expected_error_msg = "The `tree(X; kws...)` syntax is deprecated"
13+
@test occursin(expected_error_msg, e.msg)
14+
end
15+
16+
try
17+
tree'([1.0 2.0]')
18+
@test false
19+
catch e
20+
@test isa(e, ErrorException)
21+
expected_error_msg = "The `tree'(X; kws...)` syntax is deprecated"
22+
@test occursin(expected_error_msg, e.msg)
23+
end
24+
425
# Test that we generate errors:
526
baseT = Float64
627
T = Union{baseT,Vector{baseT},Matrix{baseT}}
@@ -33,11 +54,11 @@ output, flag = eval_tree_array(
3354

3455
# Default is to catch errors:
3556
try
36-
tree([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
57+
tree([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], operators)
3758
@test false
3859
catch e
3960
@test isa(e, ErrorException)
4061
end
4162

4263
# But can be overrided:
43-
output = tree([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; throw_errors=false)
64+
output = tree([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], operators; throw_errors=false)

test/test_evaluation.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ for turbo in [false, true],
6969

7070
zero_tolerance = (T <: Union{Float16,Complex} ? 1e-4 : 1e-6)
7171
@test all(abs.(test_y .- true_y) / N .< zero_tolerance)
72+
73+
test_y_helper = tree(X, operators; turbo=turbo)
74+
@test all(test_y .== test_y_helper)
7275
end
7376

7477
for turbo in [false, true], T in [Float16, Float32, Float64, ComplexF32, ComplexF64]
@@ -110,7 +113,7 @@ for turbo in [false, true], T in [Float16, Float32, Float64, ComplexF32, Complex
110113
x1 = Node(T; feature=1)
111114
tree = sin(x1 / 0.0)
112115
X = randn(Float32, 3, 10)
113-
@test isnan(tree(X; turbo=turbo)[1])
116+
@test isnan(tree(X, operators; turbo=turbo)[1])
114117
end
115118

116119
# Check if julia version >= 1.7:
@@ -124,7 +127,7 @@ if VERSION >= v"1.7"
124127
X = randn(Float32, 10)
125128
local stack
126129
try
127-
tree(X)[1]
130+
tree(X, operators)[1]
128131
@test false
129132
catch e
130133
@test e isa ErrorException
@@ -137,10 +140,10 @@ if VERSION >= v"1.7"
137140

138141
# If a method is not defined, we should get a nothing:
139142
X = randn(Float32, 1, 10)
140-
@test tree(X; throw_errors=false) === nothing
143+
@test tree(X, operators; throw_errors=false) === nothing
141144
# or a MethodError:
142145
try
143-
tree(X; throw_errors=true)
146+
tree(X, operators; throw_errors=true)
144147
@test false
145148
catch e
146149
@test e isa ErrorException

test/test_generic_operators.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ operators = GenericOperatorEnum(; binary_operators=(*,))
77

88
x1, x2, x3 = [Node(String; feature=i) for i in 1:3]
99
tree = x1 * " " * "World!"
10-
@test tree(["Hello"]) == "Hello World!"
10+
@test tree(["Hello"], operators) == "Hello World!"

test/test_tensor_operators.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,19 @@ X = [[2.0, 2.0, 2.0], [3.0, 3.0, 3.0]]
1717

1818
tree = Node(1, c1, x2)
1919
@test repr(tree) == "vec_add([1.0, 2.0, 3.0], x2)"
20-
@test tree(X) == [4.0, 5.0, 6.0]
20+
@test tree(X, operators) == [4.0, 5.0, 6.0]
2121
tree = Node(1, x1, c1)
2222
@test repr(tree) == "vec_add(x1, [1.0, 2.0, 3.0])"
23-
@test tree(X) == [3.0, 4.0, 5.0]
23+
@test tree(X, operators) == [3.0, 4.0, 5.0]
2424

2525
# Try same things, but with constructors:
2626
@extend_operators operators
2727
tree = vec_add(c1, x2)
2828
@test repr(tree) == "vec_add([1.0, 2.0, 3.0], x2)"
29-
@test tree(X) == [4.0, 5.0, 6.0]
29+
@test tree(X, operators) == [4.0, 5.0, 6.0]
3030
tree = vec_add(x1, c1)
3131
@test repr(tree) == "vec_add(x1, [1.0, 2.0, 3.0])"
32-
@test tree(X) == [3.0, 4.0, 5.0]
32+
@test tree(X, operators) == [3.0, 4.0, 5.0]
3333

3434
# Also test unary operators:
3535
function vec_square(x)
@@ -40,17 +40,17 @@ operators = GenericOperatorEnum(; binary_operators=[vec_add], unary_operators=[v
4040
@extend_operators operators
4141
tree = Node(1, c1)
4242
@test repr(tree) == "vec_square([1.0, 2.0, 3.0])"
43-
@test tree(X) == [1.0, 4.0, 9.0]
43+
@test tree(X, operators) == [1.0, 4.0, 9.0]
4444
@test vec_square(c1).val == [1.0, 4.0, 9.0]
4545
tree = Node(1, Node(1, c1), x1)
4646
@test repr(tree) == "vec_add(vec_square([1.0, 2.0, 3.0]), x1)"
47-
@test tree(X) == [3.0, 6.0, 11.0]
47+
@test tree(X, operators) == [3.0, 6.0, 11.0]
4848
@test (vec_add(vec_square(c1), x1))(X) == [3.0, 6.0, 11.0]
4949

5050
# Also test mixed scalar and floats:
5151
c2 = Node(T; val=2.0)
5252
@test repr(c2) == "2.0"
5353
tree = Node(1, Node(1, c1, x1), c2)
5454
@test repr(tree) == "vec_add(vec_add([1.0, 2.0, 3.0], x1), 2.0)"
55-
tree(X)
56-
@test tree(X) == [5.0, 6.0, 7.0]
55+
tree(X, operators)
56+
@test tree(X, operators) == [5.0, 6.0, 7.0]

0 commit comments

Comments
 (0)