Skip to content

Commit 585a2c0

Browse files
committed
Clean up issues with new constructors in unit tests
1 parent 0c280f9 commit 585a2c0

File tree

8 files changed

+59
-44
lines changed

8 files changed

+59
-44
lines changed

test/test_derivatives.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@ using DynamicExpressions: eval_diff_tree_array, eval_grad_tree_array
44
using Random
55
using Zygote
66
using LinearAlgebra
7+
include("test_params.jl")
78

89
seed = 0
910
# SIMD doesn't like abs(x) ^ y for some reason.
1011
pow_abs2(x, y) = exp(y * log(abs(x)))
11-
custom_cos(x) = cos(x)^2
1212

1313
equation1(x1, x2, x3) = x1 + x2 + x3 + 3.2
1414
equation2(x1, x2, x3) = pow_abs2(x1, x2) + x3 + custom_cos(1.0 + x3) + 3.0 / x1

test/test_initial_errors.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
using DynamicExpressions
2+
using Test
3+
using Zygote
4+
5+
# Before defining OperatorEnum, calling the implicit (deprecated)
6+
# syntax should fail:
7+
tree = Node(; feature=1)
8+
@test_throws ErrorException tree([1.0 2.0]')
9+
@test_throws "Please use the " tree([1.0 2.0]')
10+
@test_throws ErrorException tree'([1.0 2.0]')
11+
@test_throws "Please use the " tree'([1.0 2.0]')
12+
13+
@test string(tree) == "x1"
14+
@test string(Node(1, tree)) == "unary_operator[1](x1)"
15+
@test string(Node(1, tree, tree)) == "binary_operator[1](x1, x1)"
16+
17+
# Also test warnings:
18+
for constructor in (OperatorEnum, GenericOperatorEnum)
19+
operators = constructor(;
20+
binary_operators=[+, -, *, /],
21+
unary_operators=[cos, sin],
22+
(constructor == OperatorEnum ? (enable_autodiff=true,) : ())...,
23+
)
24+
tree([1.0 2.0]')
25+
# Can't test for this:
26+
# expected_warn_msg = "The `tree(X; kws...)` syntax is deprecated"
27+
# @test occursin(expected_warn_msg, msg)
28+
29+
constructor == GenericOperatorEnum && continue
30+
31+
tree'([1.0 2.0]')
32+
# Can't test for this:
33+
# expected_warn_msg = "The `tree'(X; kws...)` syntax is deprecated"
34+
# @test occursin(expected_warn_msg, msg)
35+
end

test/test_params.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ maximum_residual = 1e-2
88
safe_log10(x::T) where {T<:Number} = (x <= 0) ? T(NaN) : log10(x)
99
safe_log1p(x::T) where {T<:Number} = (x <= -1) ? T(NaN) : log1p(x)
1010
safe_sqrt(x::T) where {T<:Number} = (x < 0) ? T(NaN) : sqrt(x)
11-
safe_pow(x::T, y::T) where {T<:Number} = (x < 0 && y != round(y)) ? T(NaN) : x^y
1211
relu(x::T) where {T<:Number} = (x < 0) ? zero(T) : x
1312
safe_acosh(x::T) where {T<:Number} = (x < 1) ? T(NaN) : acosh(x)
1413
sub(x::T, y::T) where {T<:Number} = x - y
@@ -19,15 +18,14 @@ maximum_residual = 1e-2
1918
safe_log10(x) = log10(x)
2019
safe_log1p(x) = log1p(x)
2120
safe_sqrt(x) = sqrt(x)
22-
safe_pow(x, y) = x^y
2321
relu(x) = max(x, 0)
2422
safe_acosh(x) = acosh(x)
2523
sub(x, y) = x - y
2624
square(x) = x * x
2725
cube(x) = x * x * x
2826
greater(x, y) = (x > y)
2927

30-
custom_cos(x) = cos(x)
28+
custom_cos(x) = cos(x)^2
3129
end
3230

3331
HEADER_GUARD_TEST_PARAMS = true

test/test_print.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ for unaop in [safe_log, safe_log2, safe_log10, safe_log1p, safe_sqrt, safe_acosh
3232
@test string_tree(minitree, opts) == replace(string(unaop), "safe_" => "") * "(x1)"
3333
end
3434

35+
!(@isdefined safe_pow) &&
36+
@eval safe_pow(x::T, y::T) where {T<:Number} = (x < 0 && y != round(y)) ? T(NaN) : x^y
3537
for binop in [safe_pow, ^]
3638
opts = OperatorEnum(;
3739
default_params..., binary_operators=(+, *, /, -, binop), unary_operators=(cos,)

test/test_simplification.jl

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,19 +49,11 @@ tree_copy = convert(Node, eqn, operators)
4949
# with custom operators, and unary operators:
5050
x1, x2, x3 = Node("x1"), Node("x2"), Node("x3")
5151
pow_abs2(x, y) = abs(x)^y
52-
custom_cos(x) = cos(x)^2
53-
54-
# Define for Node (usually these are done internally to OperatorEnum)
55-
pow_abs2(l::Node, r::Node)::Node =
56-
(l.constant && r.constant) ? Node(pow_abs2(l.val, r.val)::Real) : Node(5, l, r)
57-
pow_abs2(l::Node, r::Real)::Node =
58-
l.constant ? Node(pow_abs2(l.val, r)::Real) : Node(5, l, r)
59-
pow_abs2(l::Real, r::Node)::Node =
60-
r.constant ? Node(pow_abs2(l, r.val)::Real) : Node(5, l, r)
61-
custom_cos(x::Node)::Node = x.constant ? Node(custom_cos(x.val)::Real) : Node(1, x)
6252

6353
operators = OperatorEnum(;
64-
binary_operators=(+, *, -, /, pow_abs2), unary_operators=(custom_cos, exp, sin)
54+
binary_operators=(+, *, -, /, pow_abs2),
55+
unary_operators=(custom_cos, exp, sin),
56+
define_helper_functions,
6557
)
6658
tree = (
6759
((x2 + x2) * ((-0.5982493 / pow_abs2(x1, x2)) / -0.54734415)) + (

test/test_symbolic_utils.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,24 @@ using Test
44
include("test_params.jl")
55

66
_inv(x) = 1 / x
7-
safe_pow(x::T, y::T) where {T} = (x < 0 && y != round(y)) ? T(NaN) : x^y
8-
greater(x::T, y::T) where {T} = (x > y) ? one(T) : zero(T)
7+
!(@isdefined safe_pow) &&
8+
@eval safe_pow(x::T, y::T) where {T<:Number} = (x < 0 && y != round(y)) ? T(NaN) : x^y
9+
!(@isdefined greater) && @eval greater(x::T, y::T) where {T} = (x > y) ? one(T) : zero(T)
10+
11+
tree =
12+
let tmp_op = OperatorEnum(;
13+
default_params...,
14+
binary_operators=(+, *, ^, /, greater),
15+
unary_operators=(_inv,),
16+
)
17+
Node(5, (Node(; val=3.0) * Node(1, Node("x1")))^2.0, Node(; val=-1.2))
18+
end
19+
920
operators = OperatorEnum(;
1021
default_params...,
1122
binary_operators=(+, *, safe_pow, /, greater),
1223
unary_operators=(_inv,),
1324
)
14-
tree = Node(5, (Node(; val=3.0) * Node(1, Node("x1")))^2.0, Node(; val=-1.2))
1525

1626
eqn = node_to_symbolic(tree, operators; variable_names=["energy"], index_functions=true)
1727
@test string(eqn) == "greater(safe_pow(3.0_inv(energy), 2.0), -1.2)"

test/test_utils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ x1, x2, x3 = Node("x1"), Node("x2"), Node("x3")
1616
# Has constants:
1717
@test has_constants(x1) == false
1818
@test has_constants(x1 + 1) == true
19-
@test has_constants(cos(x1)) == false
20-
@test has_constants(cos(Node(; val=0.0))) == true
19+
@test has_constants(sin(x1)) == false
20+
@test has_constants(sin(Node(; val=0.0))) == true
2121

2222
# Has operators
2323
@test has_operators(x1) == false
2424
@test has_operators(x1 + 1) == true
25-
@test has_operators(cos(x1)) == true
25+
@test has_operators(sin(x1)) == true
2626
@test has_operators(Node(; val=0.0)) == false
2727

2828
# Set constants:

test/unittest.jl

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,7 @@
11
using SafeTestsets
22

33
@safetestset "Initial error handling test" begin
4-
using DynamicExpressions
5-
using Test
6-
7-
# Before defining OperatorEnum, calling the implicit (deprecated)
8-
# syntax should fail:
9-
tree = Node(; feature=1)
10-
try
11-
tree([1.0 2.0]')
12-
@test false
13-
catch e
14-
@test isa(e, ErrorException)
15-
expected_error_msg = "The `tree(X; kws...)` syntax is deprecated"
16-
@test occursin(expected_error_msg, e.msg)
17-
end
18-
19-
try
20-
tree'([1.0 2.0]')
21-
@test false
22-
catch e
23-
@test isa(e, ErrorException)
24-
expected_error_msg = "The `tree'(X; kws...)` syntax is deprecated"
25-
@test occursin(expected_error_msg, e.msg)
26-
end
4+
include("test_initial_errors.jl")
275
end
286

297
@safetestset "Test tree construction and scoring" begin

0 commit comments

Comments
 (0)