|
32 | 32 | #include "dpctl4pybind11.hpp" |
33 | 33 |
|
34 | 34 | // Include generated Cython headers for usm_ndarray |
35 | | -// (struct definition and constants only) |
36 | 35 | #include "dpnp/tensor/_usmarray.h" |
37 | 36 | #include "dpnp/tensor/_usmarray_api.h" |
| 37 | +// Include usm_ndarray constants (flags, type numbers) |
| 38 | +#include "../../tensor/include/usm_ndarray_constants.h" |
38 | 39 |
|
39 | 40 | #include <array> |
40 | 41 | #include <cassert> |
@@ -191,47 +192,47 @@ class dpnp_capi |
191 | 192 | this->UsmNDArray_MakeSimpleFromPtr_ = UsmNDArray_MakeSimpleFromPtr; |
192 | 193 | this->UsmNDArray_MakeFromPtr_ = UsmNDArray_MakeFromPtr; |
193 | 194 |
|
194 | | - // constants |
195 | | - this->USM_ARRAY_C_CONTIGUOUS_ = USM_ARRAY_C_CONTIGUOUS; |
196 | | - this->USM_ARRAY_F_CONTIGUOUS_ = USM_ARRAY_F_CONTIGUOUS; |
197 | | - this->USM_ARRAY_WRITABLE_ = USM_ARRAY_WRITABLE; |
198 | | - this->UAR_BOOL_ = UAR_BOOL; |
199 | | - this->UAR_BYTE_ = UAR_BYTE; |
200 | | - this->UAR_UBYTE_ = UAR_UBYTE; |
201 | | - this->UAR_SHORT_ = UAR_SHORT; |
202 | | - this->UAR_USHORT_ = UAR_USHORT; |
203 | | - this->UAR_INT_ = UAR_INT; |
204 | | - this->UAR_UINT_ = UAR_UINT; |
205 | | - this->UAR_LONG_ = UAR_LONG; |
206 | | - this->UAR_ULONG_ = UAR_ULONG; |
207 | | - this->UAR_LONGLONG_ = UAR_LONGLONG; |
208 | | - this->UAR_ULONGLONG_ = UAR_ULONGLONG; |
209 | | - this->UAR_FLOAT_ = UAR_FLOAT; |
210 | | - this->UAR_DOUBLE_ = UAR_DOUBLE; |
211 | | - this->UAR_CFLOAT_ = UAR_CFLOAT; |
212 | | - this->UAR_CDOUBLE_ = UAR_CDOUBLE; |
213 | | - this->UAR_TYPE_SENTINEL_ = UAR_TYPE_SENTINEL; |
214 | | - this->UAR_HALF_ = UAR_HALF; |
| 195 | + // constants from usm_ndarray_constants.h |
| 196 | + this->USM_ARRAY_C_CONTIGUOUS_ = USM_ARRAY_C_CONTIGUOUS_VALUE; |
| 197 | + this->USM_ARRAY_F_CONTIGUOUS_ = USM_ARRAY_F_CONTIGUOUS_VALUE; |
| 198 | + this->USM_ARRAY_WRITABLE_ = USM_ARRAY_WRITABLE_VALUE; |
| 199 | + this->UAR_BOOL_ = UAR_BOOL_VALUE; |
| 200 | + this->UAR_BYTE_ = UAR_BYTE_VALUE; |
| 201 | + this->UAR_UBYTE_ = UAR_UBYTE_VALUE; |
| 202 | + this->UAR_SHORT_ = UAR_SHORT_VALUE; |
| 203 | + this->UAR_USHORT_ = UAR_USHORT_VALUE; |
| 204 | + this->UAR_INT_ = UAR_INT_VALUE; |
| 205 | + this->UAR_UINT_ = UAR_UINT_VALUE; |
| 206 | + this->UAR_LONG_ = UAR_LONG_VALUE; |
| 207 | + this->UAR_ULONG_ = UAR_ULONG_VALUE; |
| 208 | + this->UAR_LONGLONG_ = UAR_LONGLONG_VALUE; |
| 209 | + this->UAR_ULONGLONG_ = UAR_ULONGLONG_VALUE; |
| 210 | + this->UAR_FLOAT_ = UAR_FLOAT_VALUE; |
| 211 | + this->UAR_DOUBLE_ = UAR_DOUBLE_VALUE; |
| 212 | + this->UAR_CFLOAT_ = UAR_CFLOAT_VALUE; |
| 213 | + this->UAR_CDOUBLE_ = UAR_CDOUBLE_VALUE; |
| 214 | + this->UAR_TYPE_SENTINEL_ = UAR_TYPE_SENTINEL_VALUE; |
| 215 | + this->UAR_HALF_ = UAR_HALF_VALUE; |
215 | 216 |
|
216 | 217 | // deduced disjoint types |
217 | | - this->UAR_INT8_ = UAR_BYTE; |
218 | | - this->UAR_UINT8_ = UAR_UBYTE; |
219 | | - this->UAR_INT16_ = UAR_SHORT; |
220 | | - this->UAR_UINT16_ = UAR_USHORT; |
| 218 | + this->UAR_INT8_ = UAR_BYTE_VALUE; |
| 219 | + this->UAR_UINT8_ = UAR_UBYTE_VALUE; |
| 220 | + this->UAR_INT16_ = UAR_SHORT_VALUE; |
| 221 | + this->UAR_UINT16_ = UAR_USHORT_VALUE; |
221 | 222 | this->UAR_INT32_ = |
222 | 223 | platform_typeid_lookup<std::int32_t, long, int, short>( |
223 | | - UAR_LONG, UAR_INT, UAR_SHORT); |
| 224 | + UAR_LONG_VALUE, UAR_INT_VALUE, UAR_SHORT_VALUE); |
224 | 225 | this->UAR_UINT32_ = |
225 | 226 | platform_typeid_lookup<std::uint32_t, unsigned long, unsigned int, |
226 | | - unsigned short>(UAR_ULONG, UAR_UINT, |
227 | | - UAR_USHORT); |
| 227 | + unsigned short>( |
| 228 | + UAR_ULONG_VALUE, UAR_UINT_VALUE, UAR_USHORT_VALUE); |
228 | 229 | this->UAR_INT64_ = |
229 | 230 | platform_typeid_lookup<std::int64_t, long, long long, int>( |
230 | | - UAR_LONG, UAR_LONGLONG, UAR_INT); |
| 231 | + UAR_LONG_VALUE, UAR_LONGLONG_VALUE, UAR_INT_VALUE); |
231 | 232 | this->UAR_UINT64_ = |
232 | 233 | platform_typeid_lookup<std::uint64_t, unsigned long, |
233 | 234 | unsigned long long, unsigned int>( |
234 | | - UAR_ULONG, UAR_ULONGLONG, UAR_UINT); |
| 235 | + UAR_ULONG_VALUE, UAR_ULONGLONG_VALUE, UAR_UINT_VALUE); |
235 | 236 |
|
236 | 237 | py::object py_default_usm_memory = |
237 | 238 | ::dpctl::detail::dpctl_capi::get().default_usm_memory_pyobj(); |
|
0 commit comments