Skip to content

Commit 34d1018

Browse files
committed
Assert we can extend operators in other packages
1 parent d118673 commit 34d1018

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

test/test_custom_operators.jl

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,26 @@ end
5252

5353
end
5454

55-
# Now, test that we can work with operators defined in other modules
5655
import .A: create_and_eval_tree
5756
prediction, truth = create_and_eval_tree()
5857
@test prediction truth
58+
59+
# Now, test that we can work with operators defined in other modules
60+
module B
61+
62+
my_func_c(x::T, y::T) where {T<:Real} = x * y + T(0.3)
63+
my_func_d(x::T) where {T<:Real} = x / (abs(x)^T(0.2) + 0.1)
64+
65+
end
66+
67+
import .B: my_func_c, my_func_d
68+
operators = OperatorEnum(; binary_operators=[my_func_c], unary_operators=[my_func_d])
69+
@extend_operators operators
70+
71+
x1 = Node(Float64; feature=1)
72+
x2 = Node(Float64; feature=2)
73+
c1 = Node(Float64; val=0.2)
74+
tree = my_func_c(my_func_c(x2, 0.2), my_func_d(x1))
75+
func = (x1, x2) -> my_func_c(my_func_c(x2, 0.2), my_func_d(x1))
76+
X = randn(MersenneTwister(0), 2, 20)
77+
@test tree(X) func.(X[1, :], X[2, :])

0 commit comments

Comments
 (0)