Skip to content

Commit b99ff9c

Browse files
committed
Add new DPNPBinaryTwoOutputsFunc class
1 parent b16668c commit b99ff9c

File tree

3 files changed

+337
-3
lines changed

3 files changed

+337
-3
lines changed

doc/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from dpnp.dpnp_algo.dpnp_elementwise_common import (
1414
DPNPBinaryFunc,
1515
DPNPBinaryFuncOutKw,
16+
DPNPBinaryTwoOutputsFunc,
1617
DPNPUnaryFunc,
1718
DPNPUnaryTwoOutputsFunc,
1819
)
@@ -215,6 +216,7 @@ def _can_document_member(member, *args, **kwargs):
215216
(
216217
DPNPBinaryFunc,
217218
DPNPBinaryFuncOutKw,
219+
DPNPBinaryTwoOutputsFunc,
218220
DPNPUnaryFunc,
219221
DPNPUnaryTwoOutputsFunc,
220222
),

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 310 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,27 @@
3939
BinaryElementwiseFunc,
4040
UnaryElementwiseFunc,
4141
)
42+
from dpctl.tensor._scalar_utils import (
43+
_get_dtype,
44+
_get_shape,
45+
_validate_dtype,
46+
)
4247

4348
import dpnp
4449
import dpnp.backend.extensions.vm._vm_impl as vmi
4550
from dpnp.dpnp_array import dpnp_array
51+
from dpnp.dpnp_utils import get_usm_allocations
4652
from dpnp.dpnp_utils.dpnp_utils_common import (
4753
find_buf_dtype_3out,
54+
find_buf_dtype_4out,
4855
)
4956

5057
__all__ = [
5158
"DPNPI0",
5259
"DPNPAngle",
5360
"DPNPBinaryFunc",
5461
"DPNPBinaryFuncOutKw",
62+
"DPNPBinaryTwoOutputsFunc",
5563
"DPNPFix",
5664
"DPNPImag",
5765
"DPNPReal",
@@ -347,7 +355,7 @@ def __call__(
347355

348356
buf_dt, res1_dt, res2_dt = find_buf_dtype_3out(
349357
x.dtype,
350-
self.result_type_resolver_fn_,
358+
self.get_type_result_resolver_function(),
351359
x.sycl_device,
352360
)
353361
if res1_dt is None or res2_dt is None:
@@ -444,13 +452,12 @@ def __call__(
444452
out[i] = dpt.empty_like(x, dtype=res_dt, order=order)
445453

446454
# Call the unary function with input and output arrays
447-
dep_evs = _manager.submitted_events
448455
ht_unary_ev, unary_ev = self.get_implementation_function()(
449456
x,
450457
dpnp.get_usm_ndarray(out[0]),
451458
dpnp.get_usm_ndarray(out[1]),
452459
sycl_queue=exec_q,
453-
depends=dep_evs,
460+
depends=_manager.submitted_events,
454461
)
455462
_manager.add_event_pair(ht_unary_ev, unary_ev)
456463

@@ -795,6 +802,306 @@ def __call__(self, *args, **kwargs):
795802
return super().__call__(*args, **kwargs)
796803

797804

805+
class DPNPBinaryTwoOutputsFunc(BinaryElementwiseFunc):
    """
    Class that implements binary element-wise functions with two output arrays.

    Parameters
    ----------
    name : {str}
        Name of the binary function
    result_type_resolver_fn : {callable}
        Function that takes dtypes of the inputs and returns the dtypes of
        the results if the implementation function supports them, or
        returns `None` otherwise.
    binary_dp_impl_fn : {callable}
        Data-parallel implementation function with signature
        `impl_fn(src1: usm_ndarray, src2: usm_ndarray,
                 dst1: usm_ndarray, dst2: usm_ndarray,
                 sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])`
        where the `src1` and `src2` are the argument arrays, `dst1` and
        `dst2` are the arrays to be populated with function values,
        effectively evaluating `dst1, dst2 = func(src1, src2)`.
        The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s.
        The first event corresponds to data-management host tasks,
        including lifetime management of argument Python objects to ensure
        that their associated USM allocation is not freed before offloaded
        computational tasks complete execution, while the second event
        corresponds to computational tasks associated with function evaluation.
    docs : {str}
        Documentation string for the binary function.

    """

    def __init__(
        self,
        name,
        result_type_resolver_fn,
        binary_dp_impl_fn,
        docs,
    ):
        super().__init__(
            name,
            result_type_resolver_fn,
            binary_dp_impl_fn,
            docs,
        )
        self.__name__ = "DPNPBinaryTwoOutputsFunc"

    @property
    def nout(self):
        """Returns the number of arguments treated as outputs."""
        return 2

    def __call__(
        self,
        x1,
        x2,
        out1=None,
        out2=None,
        /,
        *,
        out=(None, None),
        where=True,
        order="K",
        dtype=None,
        subok=True,
        **kwargs,
    ):
        """
        Evaluate the binary function and return a tuple of two result arrays.

        The output arrays may be passed either positionally (`out1`, `out2`)
        or through the `out` keyword as a 2-tuple, but not both ways at once.
        An entry equal to ``None`` means that the corresponding output array
        is allocated by this call.

        Raises
        ------
        NotImplementedError
            If extra keyword arguments are passed, or if `where`, `dtype`
            or `subok` differ from their default values.
        ValueError
            If `order` is invalid, the input shapes cannot be broadcast,
            the input data types are not supported, `out` has a wrong
            length, or an output array is read-only or has a wrong shape.
        TypeError
            If `out` is not a tuple, if outputs are passed both positionally
            and by keyword, or if a result can not be cast to the data type
            of the passed output array with the "same_kind" rule.
        ExecutionPlacementError
            If queues of input and output allocations are not compatible.

        """

        if kwargs:
            raise NotImplementedError(
                f"Requested function={self.name_} with kwargs={kwargs} "
                "isn't currently supported."
            )
        elif where is not True:
            raise NotImplementedError(
                f"Requested function={self.name_} with where={where} "
                "isn't currently supported."
            )
        elif dtype is not None:
            raise NotImplementedError(
                f"Requested function={self.name_} with dtype={dtype} "
                "isn't currently supported."
            )
        elif subok is not True:
            raise NotImplementedError(
                f"Requested function={self.name_} with subok={subok} "
                "isn't currently supported."
            )

        dpnp.check_supported_arrays_type(x1, x2, scalar_type=True)

        if order is None:
            order = "K"
        elif order in "afkcAFKC" and len(order) == 1:
            # `in` on a string is a substring test, so the length check is
            # required to reject "" and multi-character values like "kc"
            order = order.upper()
        else:
            raise ValueError(
                "order must be one of 'C', 'F', 'A', or 'K' " f"(got '{order}')"
            )

        res_usm_type, exec_q = get_usm_allocations([x1, x2])
        x1 = dpnp.get_usm_ndarray_or_scalar(x1)
        x2 = dpnp.get_usm_ndarray_or_scalar(x2)

        x1_sh = _get_shape(x1)
        x2_sh = _get_shape(x2)
        try:
            res_shape = dpnp.broadcast_shapes(x1_sh, x2_sh)
        except ValueError:
            # the message below already reports both shapes
            raise ValueError(
                "operands could not be broadcast together with shapes "
                f"{x1_sh} and {x2_sh}"
            ) from None

        sycl_dev = exec_q.sycl_device
        x1_dt = _get_dtype(x1, sycl_dev)
        x2_dt = _get_dtype(x2, sycl_dev)
        if not all(_validate_dtype(dt) for dt in [x1_dt, x2_dt]):
            raise ValueError("Operands have unsupported data types")

        x1_dt, x2_dt = self.get_array_dtype_scalar_type_resolver_function()(
            x1_dt, x2_dt, sycl_dev
        )

        # `buf*_dt` is not None when the matching input has to be cast into
        # a temporary buffer of that data type prior to the computation
        buf1_dt, buf2_dt, res1_dt, res2_dt = find_buf_dtype_4out(
            x1_dt,
            x2_dt,
            self.get_type_result_resolver_function(),
            sycl_dev,
        )
        if res1_dt is None or res2_dt is None:
            raise ValueError(
                f"function '{self.name_}' does not support input type "
                f"({x1_dt}, {x2_dt}), "
                "and the input could not be safely coerced to any "
                "supported types according to the casting rule ''safe''."
            )
        buf_dts = [buf1_dt, buf2_dt]

        if not isinstance(out, tuple):
            raise TypeError("'out' must be a tuple of arrays")

        if len(out) != self.nout:
            raise ValueError(
                "'out' tuple must have exactly one entry per ufunc output"
            )

        if not (out1 is None and out2 is None):
            if all(res is None for res in out):
                out = (out1, out2)
            else:
                raise TypeError(
                    "cannot specify 'out' as both a positional and keyword argument"
                )

        # `orig_out` keeps the arrays passed by the caller, while entries of
        # `out` may be replaced with temporary buffers below
        orig_out, out = list(out), list(out)
        res_dts = [res1_dt, res2_dt]

        for i in range(self.nout):
            if out[i] is None:
                continue

            res = dpnp.get_usm_ndarray(out[i])
            if not res.flags.writable:
                raise ValueError("output array is read-only")

            if res.shape != res_shape:
                raise ValueError(
                    "The shape of input and output arrays are inconsistent. "
                    f"Expected output shape is {res_shape}, got {res.shape}"
                )

            if dpu.get_execution_queue((exec_q, res.sycl_queue)) is None:
                raise dpnp.exceptions.ExecutionPlacementError(
                    "Input and output allocation queues are not compatible"
                )

            res_dt = res_dts[i]
            if res_dt != res.dtype:
                if not dpnp.can_cast(res_dt, res.dtype, casting="same_kind"):
                    raise TypeError(
                        f"Cannot cast ufunc '{self.name_}' output {i + 1} from "
                        f"{res_dt} to {res.dtype} with casting rule 'same_kind'"
                    )

                # Allocate a temporary buffer with the required dtype
                out[i] = dpt.empty_like(res, dtype=res_dt)
            else:
                for x, dt in zip([x1, x2], buf_dts):
                    # A scalar has no memory to overlap with. If `dt` is
                    # not None, a temporary copy of `x` will be created,
                    # so the array overlap check isn't needed.
                    if dpnp.isscalar(x) or dt is not None:
                        continue
                    if not dti._array_overlap(x, res):
                        continue
                    if dti._same_logical_tensors(x, res):
                        continue

                    # Allocate a temporary buffer to avoid memory overlapping
                    out[i] = dpt.empty_like(res)
                    break

        x1 = dpnp.as_usm_ndarray(x1, dtype=x1_dt, sycl_queue=exec_q)
        x2 = dpnp.as_usm_ndarray(x2, dtype=x2_dt, sycl_queue=exec_q)

        if order == "A":
            if x1.flags.f_contiguous and x2.flags.f_contiguous:
                order = "F"
            else:
                order = "C"

        _manager = dpu.SequentialOrderManager[exec_q]
        dep_evs = _manager.submitted_events

        # Cast input array to the supported type if needed
        if any(dt is not None for dt in buf_dts):
            if all(dt is not None for dt in buf_dts):
                # NOTE(review): this replaces `order` even when the caller
                # requested "C" or "F" explicitly - confirm it is intended
                if x1.flags.c_contiguous and x2.flags.c_contiguous:
                    order = "C"
                elif x1.flags.f_contiguous and x2.flags.f_contiguous:
                    order = "F"

            arrs = [x1, x2]
            for i, buf_dt in enumerate(buf_dts):
                if buf_dt is None:
                    continue

                x = arrs[i]
                if order == "K":
                    buf = dtc._empty_like_orderK(x, buf_dt)
                else:
                    buf = dpt.empty_like(x, dtype=buf_dt, order=order)

                ht_copy_ev, copy_ev = dti._copy_usm_ndarray_into_usm_ndarray(
                    src=x, dst=buf, sycl_queue=exec_q, depends=dep_evs
                )
                _manager.add_event_pair(ht_copy_ev, copy_ev)

                arrs[i] = buf
            x1, x2 = arrs

        # Allocate a buffer for the output arrays if needed
        for i in range(self.nout):
            if out[i] is None:
                res_dt = res_dts[i]
                if order == "K":
                    out[i] = dtc._empty_like_pair_orderK(
                        x1, x2, res_dt, res_shape, res_usm_type, exec_q
                    )
                else:
                    out[i] = dpt.empty(
                        res_shape,
                        dtype=res_dt,
                        order=order,
                        usm_type=res_usm_type,
                        sycl_queue=exec_q,
                    )

        # Broadcast shapes of input arrays
        if x1.shape != res_shape:
            x1 = dpt.broadcast_to(x1, res_shape)
        if x2.shape != res_shape:
            x2 = dpt.broadcast_to(x2, res_shape)

        # Call the binary function with input and output arrays
        ht_binary_ev, binary_ev = self.get_implementation_function()(
            x1,
            x2,
            dpnp.get_usm_ndarray(out[0]),
            dpnp.get_usm_ndarray(out[1]),
            sycl_queue=exec_q,
            depends=_manager.submitted_events,
        )
        _manager.add_event_pair(ht_binary_ev, binary_ev)

        for i in range(self.nout):
            orig_res, res = orig_out[i], out[i]
            if not (orig_res is None or orig_res is res):
                # Copy the out data from temporary buffer to original memory
                ht_copy_ev, copy_ev = dti._copy_usm_ndarray_into_usm_ndarray(
                    src=res,
                    dst=dpnp.get_usm_ndarray(orig_res),
                    sycl_queue=exec_q,
                    depends=[binary_ev],
                )
                _manager.add_event_pair(ht_copy_ev, copy_ev)
                res = out[i] = orig_res

            if not isinstance(res, dpnp_array):
                # Always return dpnp.ndarray
                out[i] = dpnp_array._create_from_usm_ndarray(res)
        return tuple(out)
1104+
7981105
class DPNPAngle(DPNPUnaryFunc):
7991106
"""Class that implements dpnp.angle unary element-wise functions."""
8001107

0 commit comments

Comments
 (0)