Skip to content

Commit 2de1c72

Browse files
committed
apply review feedback
copy inputs to SpecializationConstant to prevent dangling pointers
1 parent 7e6065f commit 2de1c72

3 files changed

Lines changed: 37 additions & 19 deletions

File tree

dpctl/program/_program.pyx

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ from cpython.buffer cimport (
3737
from cpython.bytes cimport PyBytes_FromStringAndSize
3838
from libc.stdint cimport uint32_t
3939
from libc.stdlib cimport free, malloc
40-
from libc.string cimport memcmp
40+
from libc.string cimport memcmp, memcpy
4141

4242
import warnings
4343

@@ -298,12 +298,10 @@ cdef class SpecializationConstant:
298298
integers, the first argument is interpreted as the number of bytes and
299299
the second argument is interpreted as a pointer to the data.
300300
301-
Note that when constructing from a buffer, the
302-
:class:`.SpecializationConstant`, shares memory with the original object.
303-
Modifications to the original object's data after construction will be
304-
reflected when the :class:`.SpecializationConstant` is used to create a
305-
:class:`.SyclKernelBundle`. This is not the case when constructing from a
306-
raw pointer, as the data is copied.
301+
Note that construction of the :class:`.SpecializationConstant` copies the
302+
input, so modifications made after construction of the
303+
:class:`.SpecializationConstant` will not be reflected in the
304+
:class:`.SyclKernelBundle`.
307305
308306
Args:
309307
spec_id (int):
@@ -319,11 +317,12 @@ cdef class SpecializationConstant:
319317
"""
320318

321319
cdef _spec_const _spec_const
322-
cdef Py_buffer _buffer
323320

324321
def __cinit__(self, spec_id, *args):
325322
cdef int ret_code = 0
326323
cdef object target_obj = None
324+
cdef Py_buffer _local_buffer
325+
cdef void *copied_data
327326

328327
if not isinstance(spec_id, numbers.Integral):
329328
raise TypeError(
@@ -348,16 +347,16 @@ cdef class SpecializationConstant:
348347
)
349348
elif isinstance(args[0], str):
350349
target_obj = np.ascontiguousarray(args[1], dtype=args[0])
350+
else:
351+
raise TypeError(
352+
"Invalid arguments."
353+
)
351354

352355
elif len(args) == 1:
353356
target_obj = args[0]
354357
if not PyObject_CheckBuffer(target_obj):
355358
# attempt to coerce to a numpy array
356359
target_obj = np.ascontiguousarray(target_obj)
357-
else:
358-
raise TypeError(
359-
"Invalid arguments."
360-
)
361360

362361
if isinstance(target_obj, np.ndarray):
363362
if target_obj.dtype.kind not in ("b", "i", "u", "f", "c"):
@@ -372,17 +371,30 @@ cdef class SpecializationConstant:
372371
)
373372

374373
ret_code = PyObject_GetBuffer(
375-
target_obj, &(self._buffer), PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS
374+
target_obj, &(_local_buffer), PyBUF_SIMPLE | PyBUF_ANY_CONTIGUOUS
376375
)
377376
if ret_code != 0:
378377
raise ValueError(
379378
"Failed to get buffer view for the provided object."
380379
)
381-
self._spec_const.value = <void*>self._buffer.buf
382-
self._spec_const.size = <size_t>self._buffer.len
380+
381+
self._spec_const.size = <size_t>_local_buffer.len
382+
copied_data = malloc(self._spec_const.size)
383+
384+
if copied_data == NULL:
385+
PyBuffer_Release(&(_local_buffer))
386+
raise MemoryError(
387+
"Failed to allocate memory for specialization constant data."
388+
)
389+
390+
memcpy(copied_data, _local_buffer.buf, self._spec_const.size)
391+
self._spec_const.value = copied_data
392+
393+
PyBuffer_Release(&(_local_buffer))
383394

384395
def __dealloc__(self):
385-
PyBuffer_Release(&(self._buffer))
396+
if self._spec_const.value != NULL:
397+
free(<void*>self._spec_const.value)
386398

387399
def __repr__(self):
388400
return f"SpecializationConstant({self._spec_const.id})"

dpctl/program/utils/_utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ def parse_spirv_specializations(
106106

107107
if word_count == 0:
108108
raise ValueError(f"Invalid SPIR-V instruction at word index {i}")
109+
if i + word_count > len(words):
110+
raise ValueError(
111+
f"Invalid SPIR-V instruction at offset {i} (extends beyond "
112+
"buffer)"
113+
)
109114

110115
if opcode == SpirvOpCode.OpFunction:
111116
# everything following is not relevant to specialization constant
@@ -173,12 +178,14 @@ def parse_spirv_specializations(
173178
dtype_str = type_info["dtype"]
174179
raw_default = defaults.get(target_id)
175180
default_value = None
176-
if isinstance(raw_default, bytes):
181+
if isinstance(raw_default, bool):
182+
default_value = raw_default
183+
elif isinstance(raw_default, bytes) and dtype_str != "unknown_type":
177184
try:
178185
default_value = np.frombuffer(raw_default, dtype=dtype_str)[
179186
0
180187
].item()
181-
except Exception:
188+
except (ValueError, TypeError):
182189
default_value = None
183190

184191
result.append(

libsyclinterface/source/dpctl_sycl_kernel_bundle_interface.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,6 @@ _CreateKernelBundleWithIL_ze_impl(const context &SyclCtx,
503503
backend_traits<ze_be>::return_type<device> ZeDevice;
504504
ZeDevice = get_native<ze_be>(SyclDev);
505505

506-
// Specialization constants are not supported by DPCTL at the moment
507506
std::vector<std::uint32_t> spec_ids;
508507
std::vector<const void *> spec_values;
509508

0 commit comments

Comments
 (0)