Skip to content

Commit 8be1f99

Browse files
committed
Fix inference test
1 parent f7e8913 commit 8be1f99

File tree

4 files changed

+14
-15
lines changed

4 files changed

+14
-15
lines changed

src/EvaluateEquationDerivative.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ respect to `x1`.
2828
- `cX::AbstractMatrix{T}`: The data matrix, with each column being a data point.
2929
- `operators::OperatorEnum`: The operators used to create the `tree`.
3030
- `direction::Integer`: The index of the variable to take the derivative with respect to.
31-
- `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation.
31+
- `turbo::Union{Val,Bool}`: Use `LoopVectorization.@turbo` for faster evaluation.
3232
3333
# Returns
3434
@@ -40,7 +40,7 @@ function eval_diff_tree_array(
4040
cX::AbstractMatrix{T},
4141
operators::OperatorEnum,
4242
direction::Integer;
43-
turbo::Bool=false,
43+
turbo::Union{Val,Bool}=Val(false),
4444
) where {T<:Number}
4545
# TODO: Implement quick check for whether the variable is actually used
4646
# in this tree. Otherwise, return zero.

src/EvaluationHelpers.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,9 @@ end
5959

6060
# Gradients:
6161
function _grad_evaluator(
62-
tree::AbstractExpressionNode, X, operators::OperatorEnum; variable=true, kws...
62+
tree::AbstractExpressionNode, X, operators::OperatorEnum; variable=Val(true), kws...
6363
)
64-
_, grad, did_complete = eval_grad_tree_array(
65-
tree, X, operators; variable=variable, kws...
66-
)
64+
_, grad, did_complete = eval_grad_tree_array(tree, X, operators; variable, kws...)
6765
!did_complete && (grad .= convert(eltype(grad), NaN))
6866
return grad
6967
end
@@ -74,7 +72,7 @@ function _grad_evaluator(
7472
end
7573

7674
"""
77-
(tree::AbstractExpressionNode{T})'(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false, variable::Bool=true)
75+
(tree::AbstractExpressionNode{T})'(X::AbstractMatrix{T}, operators::OperatorEnum; turbo::Union{Bool,Val}=Val(false), variable::Union{Bool,Val}=Val(true))
7876
7977
Compute the forward-mode derivative of an expression, using a similar
8078
structure and optimization to eval_tree_array. `variable` specifies whether
@@ -84,9 +82,9 @@ to every constant in the expression.
8482
# Arguments
8583
- `X::AbstractMatrix{T}`: The data matrix, with each column being a data point.
8684
- `operators::OperatorEnum`: The operators used to create the `tree`.
87-
- `variable::Bool`: Whether to take derivatives with respect to features (i.e., `X` - with `variable=true`),
85+
- `variable::Union{Bool,Val}`: Whether to take derivatives with respect to features (i.e., `X` - with `variable=true`),
8886
or with respect to every constant in the expression (`variable=false`).
89-
- `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation.
87+
- `turbo::Union{Bool,Val}`: Use `LoopVectorization.@turbo` for faster evaluation.
9088
9189
# Returns
9290

test/test_derivatives.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@ function array_test(ar1, ar2; rtol=0.1)
3737
return isapprox(ar1, ar2; rtol=rtol)
3838
end
3939

40-
for type in [Float16, Float32, Float64], turbo in [true, false]
41-
type == Float16 && turbo && continue
40+
for type in [Float16, Float32, Float64], turbo in [Val(true), Val(false)]
41+
type == Float16 && turbo isa Val{true} && continue
4242

4343
println(
4444
"Testing derivatives with respect to variables, with type=$(type) and turbo=$(turbo).",
@@ -144,15 +144,15 @@ for type in [Float16, Float32, Float64], turbo in [true, false]
144144
end
145145

146146
@testset "NodeIndex" begin
147-
import DynamicExpressions: get_constants, NodeIndex, index_constants
147+
@eval import DynamicExpressions: get_constants, NodeIndex, index_constants
148148

149149
operators = OperatorEnum(;
150150
binary_operators=(+, *, -, /, pow_abs2), unary_operators=(custom_cos, exp, sin)
151151
)
152152
@extend_operators operators
153153
tree = equation3(nx1, nx2, nx3)
154154

155-
"""Check whether the ordering of constant_list is the same as the ordering of node_index."""
155+
# Check whether the ordering of constant_list is the same as the ordering of node_index.
156156
@eval function check_tree(
157157
tree::Node, node_index::NodeIndex, constant_list::AbstractVector
158158
)

test/test_evaluation.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ functions = [
3636
(x1, x2, x3) -> (sin(cos(sin(cos(x1) * x3) * 3.0) * -0.5) + 2.0) * 5.0,
3737
]
3838

39-
for turbo in [false, true], T in [Float16, Float32, Float64, ComplexF32, ComplexF64]
39+
for turbo in [Val(false), Val(true)],
40+
T in [Float16, Float32, Float64, ComplexF32, ComplexF64]
4041
# Float16 not implemented:
41-
turbo && !(T in (Float32, Float64)) && continue
42+
turbo isa Val{true} && !(T in (Float32, Float64)) && continue
4243
@testset "Test evaluation of trees with turbo=$turbo, T=$T" begin
4344
for (i_func, fnc) in enumerate(functions)
4445

0 commit comments

Comments
 (0)