indexing bug: wrong in-place updates for 4D arrays with paired integer indices #2783

@abagusetty

The following pattern, an in-place update of a 4D array through paired integer index sequences, produces incorrect results in dpnp. With NumPy and CuPy, this advanced indexing expression updates only the diagonal blocks hj[0, 0, :, :] and hj[1, 1, :, :] and leaves the off-diagonal blocks unchanged, as defined by NumPy’s semantics. In dpnp, the same code also modifies the off-diagonal blocks hj[0, 1, :, :] and hj[1, 0, :, :]: in the output below, every element of hj is decreased by 2.
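To make the semantics referred to above concrete, here is a minimal sketch using NumPy (not dpnp). It shows that two paired integer index sequences are matched element by element and select only the (i, i) blocks. The outer-product selection that covers all four blocks has to be requested explicitly, here with `np.ix_`. The variable names `paired` and `outer` are illustrative only.

```python
import numpy as np

natm = 2
hj = np.arange(2 * 2 * 3 * 3, dtype=float).reshape(2, 2, 3, 3)

# Paired index sequences are matched element by element:
# this selects hj[0, 0] and hj[1, 1] only.
paired = hj[range(natm), range(natm)]
print(paired.shape)  # (2, 3, 3)

# Selecting all four (i, j) blocks requires an outer-product index.
outer = hj[np.ix_(range(natm), range(natm))]
print(outer.shape)  # (2, 2, 3, 3)
```

Since the paired form addresses only two blocks, an in-place `+=` through it should leave hj[0, 1] and hj[1, 0] untouched.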

reproducer:

import dpnp as np

natm = 2
hj = np.arange(2 * 2 * 3 * 3, dtype=float).reshape(2, 2, 3, 3)
tmp = -np.ones((natm, 3, 3), dtype=float)

print("hj before:\n", hj)
hj[range(natm), range(natm)] += 2.0 * tmp
print("hj after:\n", hj)

dpnp output:

hj before:
 [[[[ 0.  1.  2.]
   [ 3.  4.  5.]
   [ 6.  7.  8.]]

  [[ 9. 10. 11.]
   [12. 13. 14.]
   [15. 16. 17.]]]


 [[[18. 19. 20.]
   [21. 22. 23.]
   [24. 25. 26.]]

  [[27. 28. 29.]
   [30. 31. 32.]
   [33. 34. 35.]]]]
hj after:
 [[[[-2. -1.  0.]
   [ 1.  2.  3.]
   [ 4.  5.  6.]]

  [[ 7.  8.  9.]
   [10. 11. 12.]
   [13. 14. 15.]]]


 [[[16. 17. 18.]
   [19. 20. 21.]
   [22. 23. 24.]]

  [[25. 26. 27.]
   [28. 29. 30.]
   [31. 32. 33.]]]]

Expected (NumPy/CuPy) output:

hj before:
 [[[[ 0.  1.  2.]
   [ 3.  4.  5.]
   [ 6.  7.  8.]]

  [[ 9. 10. 11.]
   [12. 13. 14.]
   [15. 16. 17.]]]


 [[[18. 19. 20.]
   [21. 22. 23.]
   [24. 25. 26.]]

  [[27. 28. 29.]
   [30. 31. 32.]
   [33. 34. 35.]]]]
hj after:
 [[[[-2. -1.  0.]
   [ 1.  2.  3.]
   [ 4.  5.  6.]]

  [[ 9. 10. 11.]
   [12. 13. 14.]
   [15. 16. 17.]]]


 [[[18. 19. 20.]
   [21. 22. 23.]
   [24. 25. 26.]]

  [[25. 26. 27.]
   [28. 29. 30.]
   [31. 32. 33.]]]]
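The expected result can also be checked block by block. The sketch below is written against NumPy and builds a reference with basic indexing only, then lists every (i, j) block where the advanced-indexing update disagrees with it. The helper `expected_diagonal_update` and the list `mismatched` are hypothetical names introduced for illustration. Running the same comparison with dpnp would assume dpnp offers equivalents of `copy` and `array_equal`, which has not been verified here.

```python
import numpy as np


def expected_diagonal_update(hj, tmp):
    """Reference result using basic indexing only (hypothetical helper)."""
    out = hj.copy()
    for i in range(tmp.shape[0]):
        out[i, i] += 2.0 * tmp[i]  # touch only the diagonal block (i, i)
    return out


natm = 2
hj = np.arange(2 * 2 * 3 * 3, dtype=float).reshape(2, 2, 3, 3)
tmp = -np.ones((natm, 3, 3), dtype=float)

expected = expected_diagonal_update(hj, tmp)

# The update from the reproducer, applied to a copy.
actual = hj.copy()
actual[range(natm), range(natm)] += 2.0 * tmp

# Collect the (i, j) blocks that differ from the reference.
mismatched = [
    (i, j)
    for i in range(natm)
    for j in range(natm)
    if not np.array_equal(actual[i, j], expected[i, j])
]
print(mismatched)  # [] with NumPy
```

Judging by the dpnp output shown above, the same comparison there would be expected to flag the off-diagonal blocks (0, 1) and (1, 0).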
