indexing bug: wrong in-place updates for 4D arrays with paired integer indices #2783

@abagusetty

Description

The following pattern, a 4D array indexed with a pair of integer index arrays, produces incorrect results. Under NumPy's advanced-indexing semantics, which CuPy follows, the two index arrays are matched element-wise, so this expression updates only the diagonal blocks hj[0, 0, :, :] and hj[1, 1, :, :] and leaves the off-diagonal blocks hj[0, 1, :, :] and hj[1, 0, :, :] unchanged. In dpnp, however, the same code also modifies the off-diagonal blocks: every element of hj[0, 1] and hj[1, 0] is decreased by 2.0, as the dpnp output below shows.
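For reference, a minimal NumPy-only sketch of the semantics described above (it does not use dpnp): the paired index arrays are broadcast together and matched element-wise, so hj[range(natm), range(natm)] selects only hj[0, 0] and hj[1, 1], giving shape (2, 3, 3), which matches tmp. Selecting every combination of the two indices instead requires np.ix_, which yields all four blocks.

```python
import numpy as np

natm = 2
hj = np.arange(2 * 2 * 3 * 3, dtype=float).reshape(2, 2, 3, 3)

# Paired (zipped) advanced indexing: picks hj[0, 0] and hj[1, 1] only.
paired = hj[range(natm), range(natm)]
print(paired.shape)  # (2, 3, 3)

# The same selection written out one pair at a time.
stacked = np.stack([hj[i, i] for i in range(natm)])
print(np.array_equal(paired, stacked))  # True

# Outer-product (orthogonal) selection needs np.ix_ and covers all four blocks.
outer = hj[np.ix_(range(natm), range(natm))]
print(outer.shape)  # (2, 2, 3, 3)
```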

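Reproducer:

import dpnp as np  # use "import numpy as np" or "import cupy as np" for the reference behaviour

natm = 2
hj = np.arange(2 * 2 * 3 * 3, dtype=float).reshape(2, 2, 3, 3)
tmp = -np.ones((natm, 3, 3), dtype=float)

print("hj before:\n", hj)
# Paired index arrays: should add 2.0 * tmp[i] to hj[i, i] only (i = 0, 1)
hj[range(natm), range(natm)] += 2.0 * tmp
print("hj after:\n", hj)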
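The reproducer can also be turned into a self-checking test. This is a sketch under stated assumptions, not part of dpnp's test suite: check_paired_index_update is a hypothetical helper that runs the same update with a given array module xp and compares the result against a NumPy reference. It uses a non-uniform update so that a value written to the wrong block cannot coincide with the correct one. to_numpy is whatever converter the backend provides (for example dpnp.asnumpy or cupy.asnumpy; np.asarray for NumPy itself).

```python
import numpy as np

def check_paired_index_update(xp, to_numpy):
    """Hypothetical helper: apply the paired-index update with module xp
    and compare the result with a NumPy reference built block by block."""
    natm = 2
    base = np.arange(2 * 2 * 3 * 3, dtype=float).reshape(2, 2, 3, 3)
    # Non-uniform update, so a block updated by mistake is detectable.
    upd = np.arange(natm * 3 * 3, dtype=float).reshape(natm, 3, 3) + 100.0

    # Reference: only the diagonal blocks hj[i, i] receive upd[i].
    expected = base.copy()
    for i in range(natm):
        expected[i, i] += upd[i]

    hj = xp.array(base)  # copy the input into the backend under test
    hj[range(natm), range(natm)] += xp.asarray(upd)
    return np.array_equal(to_numpy(hj), expected)

print(check_paired_index_update(np, np.asarray))  # True
```

Under NumPy this returns True. Given the behaviour reported here, check_paired_index_update(dpnp, dpnp.asnumpy) would be expected to return False, although that call has not been run as part of this sketch.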

dpnp output (incorrect):

hj before:
 [[[[ 0.  1.  2.]
   [ 3.  4.  5.]
   [ 6.  7.  8.]]

  [[ 9. 10. 11.]
   [12. 13. 14.]
   [15. 16. 17.]]]


 [[[18. 19. 20.]
   [21. 22. 23.]
   [24. 25. 26.]]

  [[27. 28. 29.]
   [30. 31. 32.]
   [33. 34. 35.]]]]
hj after:
 [[[[-2. -1.  0.]
   [ 1.  2.  3.]
   [ 4.  5.  6.]]

  [[ 7.  8.  9.]
   [10. 11. 12.]
   [13. 14. 15.]]]


 [[[16. 17. 18.]
   [19. 20. 21.]
   [22. 23. 24.]]

  [[25. 26. 27.]
   [28. 29. 30.]
   [31. 32. 33.]]]]
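One observation about the dpnp output above, offered only as a hypothesis: every element of all four blocks dropped by exactly 2.0, which is what NumPy produces when the two index arrays are combined as an outer product (np.ix_) instead of being matched element-wise. The NumPy sketch below checks that against the reported values. Because tmp is uniform (all -1.0), this cannot distinguish an outer-product interpretation from other ways of broadcasting the update over every block, so it says nothing definitive about the cause inside dpnp.

```python
import numpy as np

natm = 2
hj0 = np.arange(2 * 2 * 3 * 3, dtype=float).reshape(2, 2, 3, 3)
tmp = -np.ones((natm, 3, 3), dtype=float)

# Values reported by dpnp above: every element of every block is 2.0 lower.
dpnp_reported = hj0 - 2.0

# Hypothesis: outer-product indexing, so block (i, j) receives 2.0 * tmp[j].
outer = hj0.copy()
outer[np.ix_(range(natm), range(natm))] += 2.0 * tmp
print(np.array_equal(outer, dpnp_reported))  # True

# NumPy's actual semantics: only hj[0, 0] and hj[1, 1] change.
zipped = hj0.copy()
zipped[range(natm), range(natm)] += 2.0 * tmp
print(np.array_equal(zipped, dpnp_reported))  # False
```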

Expected NumPy/CuPy output:

hj before:
 [[[[ 0.  1.  2.]
   [ 3.  4.  5.]
   [ 6.  7.  8.]]

  [[ 9. 10. 11.]
   [12. 13. 14.]
   [15. 16. 17.]]]


 [[[18. 19. 20.]
   [21. 22. 23.]
   [24. 25. 26.]]

  [[27. 28. 29.]
   [30. 31. 32.]
   [33. 34. 35.]]]]
hj after:
 [[[[-2. -1.  0.]
   [ 1.  2.  3.]
   [ 4.  5.  6.]]

  [[ 9. 10. 11.]
   [12. 13. 14.]
   [15. 16. 17.]]]


 [[[18. 19. 20.]
   [21. 22. 23.]
   [24. 25. 26.]]

  [[25. 26. 27.]
   [28. 29. 30.]
   [31. 32. 33.]]]]
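Until this is fixed, one possible workaround is to replace the single advanced-indexing update with one basic-indexing update per diagonal block. This sketch assumes that dpnp's basic integer indexing (hj[i, i]) is not affected by the same bug, which this report does not establish. add_to_diagonal_blocks is a hypothetical helper name, and the code uses NumPy so it can be checked against the expected output above.

```python
import numpy as np  # assumption: swap in dpnp here to apply the workaround on a device

def add_to_diagonal_blocks(hj, update):
    """Hypothetical helper: add update[i] to block hj[i, i] for every i, in place.

    Intended to match hj[range(n), range(n)] += update under NumPy
    semantics, but uses only basic integer indexing for each block.
    """
    n = update.shape[0]
    for i in range(n):
        hj[i, i] += update[i]  # hj[i, i] is a view of the (i, i) block
    return hj

natm = 2
hj = np.arange(2 * 2 * 3 * 3, dtype=float).reshape(2, 2, 3, 3)
tmp = -np.ones((natm, 3, 3), dtype=float)
add_to_diagonal_blocks(hj, 2.0 * tmp)
print(hj[0, 1])  # unchanged: values 9.0 through 17.0
```

The trade-off is that it issues natm separate updates instead of one, which may matter on a device when natm is large.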
