1- ##===---------- dparray .py - dpctl -------*- Python -*----===##
1+ ##===---------- numpy_usm_shared .py - dpctl -------*- Python -*----===##
22##
33## Data Parallel Control (dpCtl)
44##
1919##===----------------------------------------------------------------------===##
2020###
2121### \file
22- ### This file implements a dparray - USM aware implementation of ndarray.
22+ ### This file implements a numpy_usm_shared - USM aware implementation of ndarray.
2323##===----------------------------------------------------------------------===##
2424
2525import numpy as np
@@ -70,12 +70,17 @@ class ndarray(np.ndarray):
7070 with a foreign allocator.
7171 """
7272
73+ external_usm_checkers = []
74+
75+ def add_external_usm_checker (func ):
76+ ndarray .external_usm_checkers .append (func )
77+
7378 def __new__ (
7479 subtype , shape , dtype = float , buffer = None , offset = 0 , strides = None , order = None
7580 ):
7681 # Create a new array.
7782 if buffer is None :
78- dprint ("dparray ::ndarray __new__ buffer None" )
83+ dprint ("numpy_usm_shared ::ndarray __new__ buffer None" )
7984 nelems = np .prod (shape )
8085 dt = np .dtype (dtype )
8186 isz = dt .itemsize
@@ -102,7 +107,7 @@ def __new__(
102107 return new_obj
103108 # zero copy if buffer is a usm backed array-like thing
104109 elif hasattr (buffer , array_interface_property ):
105- dprint ("dparray ::ndarray __new__ buffer" , array_interface_property )
110+ dprint ("numpy_usm_shared ::ndarray __new__ buffer" , array_interface_property )
106111 # also check for array interface
107112 new_obj = np .ndarray .__new__ (
108113 subtype ,
@@ -124,7 +129,7 @@ def __new__(
124129 )
125130 return new_obj
126131 else :
127- dprint ("dparray ::ndarray __new__ buffer not None and not sycl_usm" )
132+ dprint ("numpy_usm_shared ::ndarray __new__ buffer not None and not sycl_usm" )
128133 nelems = np .prod (shape )
129134 # must copy
130135 ar = np .ndarray (
@@ -158,6 +163,9 @@ def __new__(
158163 )
159164 return new_obj
160165
166+ def __sycl_usm_array_interface__ (self ):
167+ return self ._getter_sycl_usm_array_interface ()
168+
161169 def _getter_sycl_usm_array_interface_ (self ):
162170 ary_iface = self .__array_interface__
163171 _base = _get_usm_base (self )
@@ -186,6 +194,9 @@ def __array_finalize__(self, obj):
186194 # subclass of ndarray, including our own.
187195 if hasattr (obj , array_interface_property ):
188196 return
197+ for ext_checker in ndarray .external_usm_checkers :
198+ if ext_checker (obj ):
199+ return
189200 if isinstance (obj , np .ndarray ):
190201 ob = self
191202 while isinstance (ob , np .ndarray ):
@@ -200,7 +211,7 @@ def __array_finalize__(self, obj):
200211 )
201212
202213 # Tell Numba to not treat this type just like a NumPy ndarray but to propagate its type.
203- # This way it will use the custom dparray allocator.
214+ # This way it will use the custom numpy_usm_shared allocator.
204215 __numba_no_subtype_ndarray__ = True
205216
206217 # Convert to a NumPy ndarray.
@@ -234,8 +245,8 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
234245 else :
235246 return NotImplemented
236247 # Have to avoid recursive calls to array_ufunc here.
237- # If no out kwarg then we create a dparray out so that we get
238- # USM memory. However, if kwarg has dparray -typed out then
248+ # If no out kwarg then we create a numpy_usm_shared out so that we get
249+ # USM memory. However, if kwarg has numpy_usm_shared -typed out then
239250 # array_ufunc is called recursively so we cast out as regular
240251 # NumPy ndarray (having a USM data pointer).
241252 if kwargs .get ("out" , None ) is None :
@@ -246,7 +257,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
246257 out_as_np = np .ndarray (out .shape , out .dtype , out )
247258 kwargs ["out" ] = out_as_np
248259 else :
249- # If they manually gave dparray as out kwarg then we have to also
260+ # If they manually gave numpy_usm_shared as out kwarg then we have to also
250261 # cast as regular NumPy ndarray to avoid recursion.
251262 if isinstance (kwargs ["out" ], ndarray ):
252263 out = kwargs ["out" ]
@@ -271,7 +282,7 @@ def isdef(x):
271282 cname = c [0 ]
272283 if isdef (cname ):
273284 continue
274- # For now we do the simple thing and copy the types from NumPy module into dparray module.
285+ # For now we do the simple thing and copy the types from NumPy module into numpy_usm_shared module.
275286 new_func = "%s = np.%s" % (cname , cname )
276287 try :
277288 the_code = compile (new_func , "__init__" , "exec" )
0 commit comments