@@ -13,12 +13,10 @@ using ..NodeModule:
1313 default_allocator,
1414 with_type_parameters,
1515 leaf_copy,
16- leaf_copy!,
1716 leaf_convert,
1817 leaf_hash,
1918 leaf_equal,
2019 branch_copy,
21- branch_copy!,
2220 branch_convert,
2321 branch_hash,
2422 branch_equal,
@@ -38,6 +36,8 @@ using ..NodeUtilsModule:
3836 has_constants,
3937 get_scalar_constants,
4038 set_scalar_constants!
39+ using .. NodePreallocationModule:
40+ copy_into!, leaf_copy_into!, branch_copy_into!, allocate_container
4141using .. StringsModule: string_tree
4242using .. EvaluateModule: eval_tree_array
4343using .. EvaluateDerivativeModule: eval_grad_tree_array
@@ -96,6 +96,11 @@ function _check_with_metadata(ex::AbstractExpression)
9696end
9797
9898# # optional
99+ function _check_copy_into! (ex:: AbstractExpression )
100+ container = allocate_container (ex)
101+ prealloc_ex = copy_into! (container, ex)
102+ return container != = nothing && prealloc_ex == ex && prealloc_ex != = ex
103+ end
99104function _check_count_nodes (ex:: AbstractExpression )
100105 return count_nodes (ex) isa Int64
101106end
@@ -156,6 +161,7 @@ ei_components = (
156161 with_metadata = " returns the expression with different metadata" => _check_with_metadata,
157162 ),
158163 optional = (
164+ copy_into! = " copies an expression into a preallocated container" => _check_copy_into!,
159165 count_nodes = " counts the number of nodes in the expression tree" => _check_count_nodes,
160166 count_constant_nodes = " counts the number of constant nodes in the expression tree" => _check_count_constant_nodes,
161167 count_depth = " calculates the depth of the expression tree" => _check_count_depth,
@@ -260,14 +266,19 @@ function _check_tree_mapreduce(tree::AbstractExpressionNode)
260266end
261267
262268# # optional
269+ function _check_copy_into! (tree:: AbstractExpressionNode )
270+ container = allocate_container (tree)
271+ prealloc_tree = copy_into! (container, tree)
272+ return container != = nothing && prealloc_tree == tree && prealloc_tree != = container
273+ end
263274function _check_leaf_copy (tree:: AbstractExpressionNode )
264275 tree. degree != 0 && return true
265276 return leaf_copy (tree) isa typeof (tree)
266277end
267- function _check_leaf_copy ! (tree:: AbstractExpressionNode{T} ) where {T}
278+ function _check_leaf_copy_into ! (tree:: AbstractExpressionNode{T} ) where {T}
268279 tree. degree != 0 && return true
269280 new_leaf = constructorof (typeof (tree))(; val= zero (T))
270- ret = leaf_copy ! (new_leaf, tree)
281+ ret = leaf_copy_into ! (new_leaf, tree)
271282 return new_leaf == tree && ret === new_leaf
272283end
273284function _check_leaf_convert (tree:: AbstractExpressionNode )
@@ -292,16 +303,16 @@ function _check_branch_copy(tree::AbstractExpressionNode)
292303 return branch_copy (tree, tree. l, tree. r) isa typeof (tree)
293304 end
294305end
295- function _check_branch_copy ! (tree:: AbstractExpressionNode{T} ) where {T}
306+ function _check_branch_copy_into ! (tree:: AbstractExpressionNode{T} ) where {T}
296307 if tree. degree == 0
297308 return true
298309 end
299310 new_branch = constructorof (typeof (tree))(; val= zero (T))
300311 if tree. degree == 1
301- ret = branch_copy ! (new_branch, tree, copy (tree. l))
312+ ret = branch_copy_into ! (new_branch, tree, copy (tree. l))
302313 return new_branch == tree && ret === new_branch
303314 else
304- ret = branch_copy ! (new_branch, tree, copy (tree. l), copy (tree. r))
315+ ret = branch_copy_into ! (new_branch, tree, copy (tree. l), copy (tree. r))
305316 return new_branch == tree && ret === new_branch
306317 end
307318end
@@ -372,13 +383,14 @@ ni_components = (
372383 tree_mapreduce = " applies a function across the tree" => _check_tree_mapreduce
373384 ),
374385 optional = (
386+ copy_into! = " copies a node into a preallocated container" => _check_copy_into!,
375387 leaf_copy = " copies a leaf node" => _check_leaf_copy,
376- leaf_copy ! = " copies a leaf node in-place" => _check_leaf_copy !,
388+ leaf_copy_into ! = " copies a leaf node in-place" => _check_leaf_copy_into !,
377389 leaf_convert = " converts a leaf node" => _check_leaf_convert,
378390 leaf_hash = " computes the hash of a leaf node" => _check_leaf_hash,
379391 leaf_equal = " checks equality of two leaf nodes" => _check_leaf_equal,
380392 branch_copy = " copies a branch node" => _check_branch_copy,
381- branch_copy ! = " copies a branch node in-place" => _check_branch_copy !,
393+ branch_copy_into ! = " copies a branch node in-place" => _check_branch_copy_into !,
382394 branch_convert = " converts a branch node" => _check_branch_convert,
383395 branch_hash = " computes the hash of a branch node" => _check_branch_hash,
384396 branch_equal = " checks equality of two branch nodes" => _check_branch_equal,
@@ -419,7 +431,7 @@ ni_description = (
419431 [Arguments ()]
420432)
421433@implements (
422- NodeInterface{all_ni_methods_except ((:leaf_copy! , :branch_copy! ))},
434+ NodeInterface{all_ni_methods_except (())},
423435 GraphNode,
424436 [Arguments ()]
425437)
0 commit comments