Skip to content

Commit 11b63f5

Browse files
[API compatibility] Modify the meshgrid API and add the Tensor.split_with_sizes API. (#76132)
* Modify the meshgrid API and add the Tensor.split_with_sizes API. * update * update * update
1 parent 21d46fd commit 11b63f5

File tree

4 files changed

+207
-6
lines changed

4 files changed

+207
-6
lines changed

python/paddle/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
range,
5858
resize_,
5959
set_,
60+
split_with_sizes,
6061
to_tensor,
6162
tril,
6263
tril_,
@@ -949,6 +950,7 @@
949950
'greater',
950951
'clamp',
951952
'clamp_',
953+
'split_with_sizes',
952954
]
953955

954956

python/paddle/tensor/creation.py

Lines changed: 75 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2556,13 +2556,15 @@ def triu_(
25562556

25572557
@overload
25582558
def meshgrid(
2559-
args: Sequence[paddle.Tensor], name: str | None = None
2559+
args: Sequence[paddle.Tensor],
2560+
name: str | None = None,
2561+
indexing: str | None = None,
25602562
) -> list[paddle.Tensor]: ...
25612563

25622564

25632565
@overload
25642566
def meshgrid(
2565-
*args: paddle.Tensor, name: str | None = None
2567+
*args: paddle.Tensor, name: str | None = None, indexing: str | None = None
25662568
) -> list[paddle.Tensor]: ...
25672569

25682570

@@ -2577,7 +2579,9 @@ def meshgrid(*args, **kwargs):
25772579
**kwargs (optional): Currently, only accept name in **kwargs
25782580
The default value is None. Normally there is no need for
25792581
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
2580-
2582+
indexing (Optional[str]) : the indexing mode, either “xy” or “ij”, defaults to “ij”.If “xy” is selected, the first dimension corresponds to the cardinality
2583+
of the second input and the second dimension corresponds to the cardinality of the first input. If “ij” is selected, the dimensions are in the
2584+
same order as the cardinality of the inputs.
25812585
Returns:
25822586
Tensor: k tensors. The shape of each tensor is (N1, N2, ..., Nk)
25832587
@@ -2597,13 +2601,26 @@ def meshgrid(*args, **kwargs):
25972601
[100, 200]
25982602
25992603
"""
2604+
name = kwargs.get("name", None)
2605+
indexing = kwargs.pop("indexing", None)
2606+
if indexing is None:
2607+
indexing = "ij"
26002608

26012609
if len(args) == 1 and isinstance(args[0], (list, tuple)):
26022610
args = args[0]
2611+
2612+
if indexing not in ("ij", "xy"):
2613+
raise ValueError(
2614+
f"meshgrid: indexing must be 'ij' or 'xy', but got {indexing}"
2615+
)
2616+
2617+
swap_xy = indexing == "xy" and len(args) >= 2
2618+
if swap_xy:
2619+
args = (args[1], args[0], *args[2:])
2620+
26032621
if in_dynamic_or_pir_mode():
2604-
return _C_ops.meshgrid(list(args))
2622+
out = _C_ops.meshgrid(list(args))
26052623
else:
2606-
name = kwargs.get("name", None)
26072624
helper = LayerHelper('meshgrid', **locals())
26082625

26092626
if not isinstance(args, (list, tuple)):
@@ -2637,7 +2654,59 @@ def meshgrid(*args, **kwargs):
26372654
type='meshgrid', inputs={'X': list(args)}, outputs={'Out': out}
26382655
)
26392656

2640-
return out
2657+
if swap_xy:
2658+
out[0], out[1] = out[1], out[0]
2659+
return out
2660+
2661+
2662+
def split_with_sizes(
2663+
self: paddle.Tensor, split_sizes: list[int], dim: int = 0
2664+
) -> list[paddle.Tensor]:
2665+
"""
2666+
Splits the input tensor into multiple sub tensors according to given split sizes.
2667+
2668+
Args:
2669+
self (Tensor): The input tensor to be split.
2670+
split_sizes (list[int]): A list of non negative integers specifying
2671+
the sizes of each split along dimension ``dim``. The sum of all
2672+
elements in this list must equal the size of ``self`` along ``dim``.
2673+
dim (int, optional): The dimension along which to split the tensor.
2674+
Defaults to 0.
2675+
2676+
Returns:
2677+
list[Tensor]: A list of sub tensors resulting from splitting ``self``
2678+
along the specified dimension.
2679+
2680+
Examples:
2681+
.. code-block:: python
2682+
2683+
>>> import paddle
2684+
>>> x = paddle.to_tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
2685+
>>> # Split into two parts along the first dimension, of sizes 1 and 2
2686+
>>> splits = paddle.Tensor.split_with_sizes(x, [1, 2], dim=0)
2687+
>>> print(splits)
2688+
"""
2689+
for size in split_sizes:
2690+
if size < 0:
2691+
raise ValueError(
2692+
"split_with_sizes expects split_sizes have only non-negative entries"
2693+
)
2694+
2695+
total = sum(split_sizes)
2696+
if total != self.shape[dim]:
2697+
raise ValueError(
2698+
f"Split sizes add up to {total} but got the tensor's size of {self.shape[dim]}"
2699+
)
2700+
2701+
outs = []
2702+
start = 0
2703+
for size in split_sizes:
2704+
end = start + size
2705+
out = paddle.slice(self, axes=[dim], starts=[start], ends=[end])
2706+
outs.append(out)
2707+
start = end
2708+
2709+
return outs
26412710

26422711

26432712
def diag_embed(

test/legacy_test/test_meshgrid_op.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,74 @@ def test_api_with_dygraph(self):
311311
np.testing.assert_array_equal(res_4.shape, [100, 200])
312312

313313

314+
class TestMeshgridOpIndexing(unittest.TestCase):
315+
def setUp(self):
316+
self.input_3 = np.random.randint(0, 100, [100]).astype('int32')
317+
self.input_4 = np.random.randint(0, 100, [200]).astype('int32')
318+
319+
def test_api_with_dygraph_indexing_xy(self):
320+
np_res_3, np_res_4 = np.meshgrid(
321+
self.input_3, self.input_4, indexing='xy'
322+
)
323+
324+
with base.dygraph.guard():
325+
tensor_3 = paddle.to_tensor(self.input_3)
326+
tensor_4 = paddle.to_tensor(self.input_4)
327+
res_3, res_4 = paddle.tensor.meshgrid(
328+
tensor_3, tensor_4, indexing='xy'
329+
)
330+
331+
np.testing.assert_array_equal(res_3.shape, np_res_3.shape)
332+
np.testing.assert_array_equal(res_4.shape, np_res_4.shape)
333+
np.testing.assert_array_equal(res_3.numpy(), np_res_3)
334+
np.testing.assert_array_equal(res_3.numpy(), np_res_3)
335+
np.testing.assert_array_equal(res_4.numpy(), np_res_4)
336+
337+
def test_api_with_dygraph_indexing_ij(self):
338+
np_res_3, np_res_4 = np.meshgrid(
339+
self.input_3, self.input_4, indexing='ij'
340+
)
341+
342+
with base.dygraph.guard():
343+
tensor_3 = paddle.to_tensor(self.input_3)
344+
tensor_4 = paddle.to_tensor(self.input_4)
345+
res_3, res_4 = paddle.tensor.meshgrid(
346+
tensor_3, tensor_4, indexing='ij'
347+
)
348+
349+
np.testing.assert_array_equal(res_3.shape, np_res_3.shape)
350+
np.testing.assert_array_equal(res_4.shape, np_res_4.shape)
351+
np.testing.assert_array_equal(res_3.numpy(), np_res_3)
352+
np.testing.assert_array_equal(res_4.numpy(), np_res_4)
353+
354+
def test_indexing_default(self):
355+
np_res_3, np_res_4 = np.meshgrid(
356+
self.input_3, self.input_4, indexing='ij'
357+
)
358+
359+
with base.dygraph.guard():
360+
tensor_3 = paddle.to_tensor(self.input_3)
361+
tensor_4 = paddle.to_tensor(self.input_4)
362+
res_3, res_4 = paddle.tensor.meshgrid(tensor_3, tensor_4)
363+
res_3_n, res_4_n = paddle.tensor.meshgrid(
364+
tensor_3, tensor_4, indexing=None
365+
)
366+
np.testing.assert_array_equal(res_3.numpy(), np_res_3)
367+
np.testing.assert_array_equal(res_4.numpy(), np_res_4)
368+
np.testing.assert_array_equal(res_3_n.numpy(), np_res_3)
369+
np.testing.assert_array_equal(res_4_n.numpy(), np_res_4)
370+
371+
def test_indexing_invalid_value(self):
372+
with base.dygraph.guard():
373+
tensor_3 = paddle.to_tensor(self.input_3)
374+
tensor_4 = paddle.to_tensor(self.input_4)
375+
invalid_indexing = "ab"
376+
with self.assertRaises(ValueError) as cm:
377+
res_3, res_4 = paddle.tensor.meshgrid(
378+
tensor_3, tensor_4, indexing=invalid_indexing
379+
)
380+
381+
314382
class TestMeshgridOp7(unittest.TestCase):
315383
def test_api_with_dygraph_list_input(self):
316384
input_3 = np.random.randint(
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import unittest
17+
18+
import numpy as np
19+
20+
import paddle
21+
22+
23+
class TestSplitWithSizes(unittest.TestCase):
24+
def setUp(self):
25+
self.x = paddle.arange(12).reshape([3, 4])
26+
self.split_sizes = [1, 2]
27+
self.dim = 0
28+
29+
def test_basic_functionality(self):
30+
splits = paddle.Tensor.split_with_sizes(
31+
self.x, self.split_sizes, dim=self.dim
32+
)
33+
34+
self.assertEqual(len(splits), len(self.split_sizes))
35+
36+
expected_shapes = [[1, 4], [2, 4]]
37+
for s, shape in zip(splits, expected_shapes):
38+
self.assertListEqual(list(s.shape), shape)
39+
40+
np_x = self.x.numpy()
41+
start = 0
42+
for i, size in enumerate(self.split_sizes):
43+
np_ref = np_x[start : start + size, :]
44+
np.testing.assert_array_equal(splits[i].numpy(), np_ref)
45+
start += size
46+
47+
def test_ValueError_raises(self):
48+
invalid_split_sizes = [1, -2]
49+
with self.assertRaises(ValueError) as cm:
50+
paddle.Tensor.split_with_sizes(
51+
self.x, invalid_split_sizes, dim=self.dim
52+
)
53+
54+
invalid_split_sizes = [1, 1]
55+
with self.assertRaises(ValueError) as cm:
56+
paddle.Tensor.split_with_sizes(
57+
self.x, invalid_split_sizes, dim=self.dim
58+
)
59+
60+
61+
if __name__ == "__main__":
62+
unittest.main()

0 commit comments

Comments
 (0)