11import os
2- from numbers import Number
2+ from numbers import Number , Integral
33import numpy as np
44from mpi4py import MPI
55from .pencil import Pencil , Subcomm
@@ -24,12 +24,14 @@ class DistArray(np.ndarray):
2424 dtype : np.dtype, optional
2525 Type of array
2626 buffer : Numpy array, optional
27- Array of correct shape
27+ Array of correct shape. The buffer owns the memory that is used for
28+ this array.
2829 alignment : None or int, optional
2930 Make sure array is aligned in this direction. Note that alignment does
3031 not take rank into consideration.
3132 rank : int, optional
32- Rank of tensor (scalar is zero, vector one, matrix two)
33+ Rank of tensor (number of free indices, a scalar is zero, vector one,
34+ matrix two)
3335
3436
3537 For more information, see `numpy.ndarray <https://docs.scipy.org/doc/numpy/reference/arrays.ndarray.html>`_
@@ -55,11 +57,12 @@ class DistArray(np.ndarray):
5557 """
5658 def __new__ (cls , global_shape , subcomm = None , val = None , dtype = np .float ,
5759 buffer = None , alignment = None , rank = 0 ):
58- if len (global_shape [rank :]) < 2 :
60+ if len (global_shape [rank :]) < 2 : # 1D case
5961 obj = np .ndarray .__new__ (cls , global_shape , dtype = dtype , buffer = buffer )
6062 if buffer is None and isinstance (val , Number ):
6163 obj .fill (val )
6264 obj ._rank = rank
65+ obj ._p0 = None
6366 return obj
6467
6568 if isinstance (subcomm , Subcomm ):
@@ -155,17 +158,20 @@ def __getitem__(self, i):
155158 if self .ndim == 1 :
156159 return np .ndarray .__getitem__ (self , i )
157160
158- if isinstance (i , (int , slice )) and self .rank > 0 :
161+ if isinstance (i , (Integral , slice )) and self .rank > 0 :
159162 v0 = np .ndarray .__getitem__ (self , i )
160163 v0 ._rank = self .rank - (self .ndim - v0 .ndim )
161- #if v0.ndim < self.ndim:
162- # v0._rank -= 1
163164 return v0
164165
165- if isinstance (i , tuple ) and len (i ) == 2 and self .rank == 2 :
166+ if isinstance (i , (Integral , slice )) and self .rank == 0 :
167+ return np .ndarray .__getitem__ (self .v , i )
168+
169+ assert isinstance (i , tuple )
170+ if len (i ) <= self .rank :
166171 v0 = np .ndarray .__getitem__ (self , i )
167- v0 ._rank = 0
172+ v0 ._rank = self . rank - ( self . ndim - v0 . ndim )
168173 return v0
174+
169175 return np .ndarray .__getitem__ (self .v , i )
170176
171177 @property
@@ -445,7 +451,7 @@ def newDistArray(pfft, forward_output=True, val=0, rank=0, view=False):
445451 val : int or float, optional
446452 Value used to initialize array.
447453 rank: int, optional
448- Scalar has rank 0, vector 1 and matrix 2
454+ Scalar has rank 0, vector 1 and matrix 2.
449455 view : bool, optional
450456 If True return view of the underlying Numpy array, i.e., return
451457 cls.view(np.ndarray). Note that the DistArray still will
0 commit comments