|
39 | 39 | BinaryElementwiseFunc, |
40 | 40 | UnaryElementwiseFunc, |
41 | 41 | ) |
| 42 | +from dpctl.tensor._scalar_utils import ( |
| 43 | + _get_dtype, |
| 44 | + _get_shape, |
| 45 | + _validate_dtype, |
| 46 | +) |
42 | 47 |
|
43 | 48 | import dpnp |
44 | 49 | import dpnp.backend.extensions.vm._vm_impl as vmi |
45 | 50 | from dpnp.dpnp_array import dpnp_array |
| 51 | +from dpnp.dpnp_utils import get_usm_allocations |
46 | 52 | from dpnp.dpnp_utils.dpnp_utils_common import ( |
47 | 53 | find_buf_dtype_3out, |
| 54 | + find_buf_dtype_4out, |
48 | 55 | ) |
49 | 56 |
|
50 | 57 | __all__ = [ |
51 | 58 | "DPNPI0", |
52 | 59 | "DPNPAngle", |
53 | 60 | "DPNPBinaryFunc", |
54 | 61 | "DPNPBinaryFuncOutKw", |
| 62 | + "DPNPBinaryTwoOutputsFunc", |
55 | 63 | "DPNPFix", |
56 | 64 | "DPNPImag", |
57 | 65 | "DPNPReal", |
@@ -347,7 +355,7 @@ def __call__( |
347 | 355 |
|
348 | 356 | buf_dt, res1_dt, res2_dt = find_buf_dtype_3out( |
349 | 357 | x.dtype, |
350 | | - self.result_type_resolver_fn_, |
| 358 | + self.get_type_result_resolver_function(), |
351 | 359 | x.sycl_device, |
352 | 360 | ) |
353 | 361 | if res1_dt is None or res2_dt is None: |
@@ -444,13 +452,12 @@ def __call__( |
444 | 452 | out[i] = dpt.empty_like(x, dtype=res_dt, order=order) |
445 | 453 |
|
446 | 454 | # Call the unary function with input and output arrays |
447 | | - dep_evs = _manager.submitted_events |
448 | 455 | ht_unary_ev, unary_ev = self.get_implementation_function()( |
449 | 456 | x, |
450 | 457 | dpnp.get_usm_ndarray(out[0]), |
451 | 458 | dpnp.get_usm_ndarray(out[1]), |
452 | 459 | sycl_queue=exec_q, |
453 | | - depends=dep_evs, |
| 460 | + depends=_manager.submitted_events, |
454 | 461 | ) |
455 | 462 | _manager.add_event_pair(ht_unary_ev, unary_ev) |
456 | 463 |
|
@@ -795,6 +802,306 @@ def __call__(self, *args, **kwargs): |
795 | 802 | return super().__call__(*args, **kwargs) |
796 | 803 |
|
797 | 804 |
|
| 805 | +class DPNPBinaryTwoOutputsFunc(BinaryElementwiseFunc): |
| 806 | + """ |
| 807 | + Class that implements unary element-wise functions with two output arrays. |
| 808 | +
|
| 809 | + Parameters |
| 810 | + ---------- |
| 811 | + name : {str} |
| 812 | + Name of the unary function |
| 813 | + result_type_resolver_fn : {callable} |
| 814 | + Function that takes dtype of the input and returns the dtype of |
| 815 | + the result if the implementation functions supports it, or |
| 816 | + returns `None` otherwise. |
| 817 | + unary_dp_impl_fn : {callable} |
| 818 | + Data-parallel implementation function with signature |
| 819 | + `impl_fn(src: usm_ndarray, dst: usm_ndarray, |
| 820 | + sycl_queue: SyclQueue, depends: Optional[List[SyclEvent]])` |
| 821 | + where the `src` is the argument array, `dst` is the |
| 822 | + array to be populated with function values, effectively |
| 823 | + evaluating `dst = func(src)`. |
| 824 | + The `impl_fn` is expected to return a 2-tuple of `SyclEvent`s. |
| 825 | + The first event corresponds to data-management host tasks, |
| 826 | + including lifetime management of argument Python objects to ensure |
| 827 | + that their associated USM allocation is not freed before offloaded |
| 828 | + computational tasks complete execution, while the second event |
| 829 | + corresponds to computational tasks associated with function evaluation. |
| 830 | + docs : {str} |
| 831 | + Documentation string for the unary function. |
| 832 | + mkl_fn_to_call : {None, str} |
| 833 | + Check input arguments to answer if function from OneMKL VM library |
| 834 | + can be used. |
| 835 | + mkl_impl_fn : {None, str} |
| 836 | + Function from OneMKL VM library to call. |
| 837 | +
|
| 838 | + """ |
| 839 | + |
| 840 | + def __init__( |
| 841 | + self, |
| 842 | + name, |
| 843 | + result_type_resolver_fn, |
| 844 | + binary_dp_impl_fn, |
| 845 | + docs, |
| 846 | + ): |
| 847 | + super().__init__( |
| 848 | + name, |
| 849 | + result_type_resolver_fn, |
| 850 | + binary_dp_impl_fn, |
| 851 | + docs, |
| 852 | + ) |
| 853 | + self.__name__ = "DPNPBinaryTwoOutputsFunc" |
| 854 | + |
| 855 | + @property |
| 856 | + def nout(self): |
| 857 | + """Returns the number of arguments treated as outputs.""" |
| 858 | + return 2 |
| 859 | + |
| 860 | + def __call__( |
| 861 | + self, |
| 862 | + x1, |
| 863 | + x2, |
| 864 | + out1=None, |
| 865 | + out2=None, |
| 866 | + /, |
| 867 | + *, |
| 868 | + out=(None, None), |
| 869 | + where=True, |
| 870 | + order="K", |
| 871 | + dtype=None, |
| 872 | + subok=True, |
| 873 | + **kwargs, |
| 874 | + ): |
| 875 | + if kwargs: |
| 876 | + raise NotImplementedError( |
| 877 | + f"Requested function={self.name_} with kwargs={kwargs} " |
| 878 | + "isn't currently supported." |
| 879 | + ) |
| 880 | + elif where is not True: |
| 881 | + raise NotImplementedError( |
| 882 | + f"Requested function={self.name_} with where={where} " |
| 883 | + "isn't currently supported." |
| 884 | + ) |
| 885 | + elif dtype is not None: |
| 886 | + raise NotImplementedError( |
| 887 | + f"Requested function={self.name_} with dtype={dtype} " |
| 888 | + "isn't currently supported." |
| 889 | + ) |
| 890 | + elif subok is not True: |
| 891 | + raise NotImplementedError( |
| 892 | + f"Requested function={self.name_} with subok={subok} " |
| 893 | + "isn't currently supported." |
| 894 | + ) |
| 895 | + |
| 896 | + dpnp.check_supported_arrays_type(x1, x2, scalar_type=True) |
| 897 | + |
| 898 | + if order is None: |
| 899 | + order = "K" |
| 900 | + elif order in "afkcAFKC": |
| 901 | + order = order.upper() |
| 902 | + else: |
| 903 | + raise ValueError( |
| 904 | + "order must be one of 'C', 'F', 'A', or 'K' " f"(got '{order}')" |
| 905 | + ) |
| 906 | + |
| 907 | + res_usm_type, exec_q = get_usm_allocations([x1, x2]) |
| 908 | + x1 = dpnp.get_usm_ndarray_or_scalar(x1) |
| 909 | + x2 = dpnp.get_usm_ndarray_or_scalar(x2) |
| 910 | + |
| 911 | + x1_sh = _get_shape(x1) |
| 912 | + x2_sh = _get_shape(x2) |
| 913 | + try: |
| 914 | + res_shape = dpnp.broadcast_shapes(x1_sh, x2_sh) |
| 915 | + except ValueError: |
| 916 | + raise ValueError( |
| 917 | + "operands could not be broadcast together with shapes " |
| 918 | + f"{x1_sh} and {x2_sh}" |
| 919 | + ) |
| 920 | + |
| 921 | + sycl_dev = exec_q.sycl_device |
| 922 | + x1_dt = _get_dtype(x1, sycl_dev) |
| 923 | + x2_dt = _get_dtype(x2, sycl_dev) |
| 924 | + if not all(_validate_dtype(dt) for dt in [x1_dt, x2_dt]): |
| 925 | + raise ValueError("Operands have unsupported data types") |
| 926 | + |
| 927 | + x1_dt, x2_dt = self.get_array_dtype_scalar_type_resolver_function()( |
| 928 | + x1_dt, x2_dt, sycl_dev |
| 929 | + ) |
| 930 | + |
| 931 | + buf1_dt, buf2_dt, res1_dt, res2_dt = find_buf_dtype_4out( |
| 932 | + x1_dt, |
| 933 | + x2_dt, |
| 934 | + self.get_type_result_resolver_function(), |
| 935 | + sycl_dev, |
| 936 | + ) |
| 937 | + if res1_dt is None or res2_dt is None: |
| 938 | + raise ValueError( |
| 939 | + f"function '{self.name_}' does not support input type " |
| 940 | + f"({x1_dt}, {x2_dt}), " |
| 941 | + "and the input could not be safely coerced to any " |
| 942 | + "supported types according to the casting rule ''safe''." |
| 943 | + ) |
| 944 | + buf_dts = [buf1_dt, buf2_dt] |
| 945 | + |
| 946 | + if not isinstance(out, tuple): |
| 947 | + raise TypeError("'out' must be a tuple of arrays") |
| 948 | + |
| 949 | + if len(out) != self.nout: |
| 950 | + raise ValueError( |
| 951 | + "'out' tuple must have exactly one entry per ufunc output" |
| 952 | + ) |
| 953 | + |
| 954 | + if not (out1 is None and out2 is None): |
| 955 | + if all(res is None for res in out): |
| 956 | + out = (out1, out2) |
| 957 | + else: |
| 958 | + raise TypeError( |
| 959 | + "cannot specify 'out' as both a positional and keyword argument" |
| 960 | + ) |
| 961 | + |
| 962 | + orig_out, out = list(out), list(out) |
| 963 | + res_dts = [res1_dt, res2_dt] |
| 964 | + |
| 965 | + for i in range(self.nout): |
| 966 | + if out[i] is None: |
| 967 | + continue |
| 968 | + |
| 969 | + res = dpnp.get_usm_ndarray(out[i]) |
| 970 | + if not res.flags.writable: |
| 971 | + raise ValueError("output array is read-only") |
| 972 | + |
| 973 | + if res.shape != res_shape: |
| 974 | + raise ValueError( |
| 975 | + "The shape of input and output arrays are inconsistent. " |
| 976 | + f"Expected output shape is {res_shape}, got {res.shape}" |
| 977 | + ) |
| 978 | + |
| 979 | + if dpu.get_execution_queue((exec_q, res.sycl_queue)) is None: |
| 980 | + raise dpnp.exceptions.ExecutionPlacementError( |
| 981 | + "Input and output allocation queues are not compatible" |
| 982 | + ) |
| 983 | + |
| 984 | + res_dt = res_dts[i] |
| 985 | + if res_dt != res.dtype: |
| 986 | + if not dpnp.can_cast(res_dt, res.dtype, casting="same_kind"): |
| 987 | + raise TypeError( |
| 988 | + f"Cannot cast ufunc '{self.name_}' output {i + 1} from " |
| 989 | + f"{res_dt} to {res.dtype} with casting rule 'same_kind'" |
| 990 | + ) |
| 991 | + |
| 992 | + # Allocate a temporary buffer with the required dtype |
| 993 | + out[i] = dpt.empty_like(res, dtype=res_dt) |
| 994 | + else: |
| 995 | + for x, dt in zip([x1, x2], buf_dts): |
| 996 | + if dpnp.isscalar(x): |
| 997 | + pass |
| 998 | + elif dt is not None: |
| 999 | + pass |
| 1000 | + elif not dti._array_overlap(x, res): |
| 1001 | + pass |
| 1002 | + elif dti._same_logical_tensors(x, res): |
| 1003 | + pass |
| 1004 | + |
| 1005 | + # Allocate a temporary buffer to avoid memory overlapping. |
| 1006 | + # Note if `dt` is not None, a temporary copy of `x` will be |
| 1007 | + # created, so the array overlap check isn't needed. |
| 1008 | + out[i] = dpt.empty_like(res) |
| 1009 | + break |
| 1010 | + |
| 1011 | + x1 = dpnp.as_usm_ndarray(x1, dtype=x1_dt, sycl_queue=exec_q) |
| 1012 | + x2 = dpnp.as_usm_ndarray(x2, dtype=x2_dt, sycl_queue=exec_q) |
| 1013 | + |
| 1014 | + if order == "A": |
| 1015 | + if x1.flags.f_contiguous and x2.flags.f_contiguous: |
| 1016 | + order = "F" |
| 1017 | + else: |
| 1018 | + order = "C" |
| 1019 | + |
| 1020 | + _manager = dpu.SequentialOrderManager[exec_q] |
| 1021 | + dep_evs = _manager.submitted_events |
| 1022 | + |
| 1023 | + # Cast input array to the supported type if needed |
| 1024 | + if any(dt is not None for dt in buf_dts): |
| 1025 | + if all(dt is not None for dt in buf_dts): |
| 1026 | + if x1.flags.c_contiguous and x2.flags.c_contiguous: |
| 1027 | + order = "C" |
| 1028 | + elif x1.flags.f_contiguous and x2.flags.f_contiguous: |
| 1029 | + order = "F" |
| 1030 | + |
| 1031 | + arrs = [x1, x2] |
| 1032 | + buf_dts = [buf1_dt, buf2_dt] |
| 1033 | + for i in range(self.nout): |
| 1034 | + buf_dt = buf_dts[i] |
| 1035 | + if buf_dt is None: |
| 1036 | + continue |
| 1037 | + |
| 1038 | + x = arrs[i] |
| 1039 | + if order == "K": |
| 1040 | + buf = dtc._empty_like_orderK(x, buf_dt) |
| 1041 | + else: |
| 1042 | + buf = dpt.empty_like(x, dtype=buf_dt, order=order) |
| 1043 | + |
| 1044 | + ht_copy_ev, copy_ev = dti._copy_usm_ndarray_into_usm_ndarray( |
| 1045 | + src=x, dst=buf, sycl_queue=exec_q, depends=dep_evs |
| 1046 | + ) |
| 1047 | + _manager.add_event_pair(ht_copy_ev, copy_ev) |
| 1048 | + |
| 1049 | + arrs[i] = buf |
| 1050 | + x1, x2 = arrs |
| 1051 | + |
| 1052 | + # Allocate a buffer for the output arrays if needed |
| 1053 | + for i in range(self.nout): |
| 1054 | + if out[i] is None: |
| 1055 | + res_dt = res_dts[i] |
| 1056 | + if order == "K": |
| 1057 | + out[i] = dtc._empty_like_pair_orderK( |
| 1058 | + x1, x2, res_dt, res_shape, res_usm_type, exec_q |
| 1059 | + ) |
| 1060 | + else: |
| 1061 | + out[i] = dpt.empty( |
| 1062 | + res_shape, |
| 1063 | + dtype=res_dt, |
| 1064 | + order=order, |
| 1065 | + usm_type=res_usm_type, |
| 1066 | + sycl_queue=exec_q, |
| 1067 | + ) |
| 1068 | + |
| 1069 | + # Broadcast shapes of input arrays |
| 1070 | + if x1.shape != res_shape: |
| 1071 | + x1 = dpt.broadcast_to(x1, res_shape) |
| 1072 | + if x2.shape != res_shape: |
| 1073 | + x2 = dpt.broadcast_to(x2, res_shape) |
| 1074 | + |
| 1075 | + # Call the binary function with input and output arrays |
| 1076 | + ht_binary_ev, binary_ev = self.get_implementation_function()( |
| 1077 | + x1, |
| 1078 | + x2, |
| 1079 | + dpnp.get_usm_ndarray(out[0]), |
| 1080 | + dpnp.get_usm_ndarray(out[1]), |
| 1081 | + sycl_queue=exec_q, |
| 1082 | + depends=_manager.submitted_events, |
| 1083 | + ) |
| 1084 | + _manager.add_event_pair(ht_binary_ev, binary_ev) |
| 1085 | + |
| 1086 | + for i in range(self.nout): |
| 1087 | + orig_res, res = orig_out[i], out[i] |
| 1088 | + if not (orig_res is None or orig_res is res): |
| 1089 | + # Copy the out data from temporary buffer to original memory |
| 1090 | + ht_copy_ev, copy_ev = dti._copy_usm_ndarray_into_usm_ndarray( |
| 1091 | + src=res, |
| 1092 | + dst=dpnp.get_usm_ndarray(orig_res), |
| 1093 | + sycl_queue=exec_q, |
| 1094 | + depends=[binary_ev], |
| 1095 | + ) |
| 1096 | + _manager.add_event_pair(ht_copy_ev, copy_ev) |
| 1097 | + res = out[i] = orig_res |
| 1098 | + |
| 1099 | + if not isinstance(res, dpnp_array): |
| 1100 | + # Always return dpnp.ndarray |
| 1101 | + out[i] = dpnp_array._create_from_usm_ndarray(res) |
| 1102 | + return tuple(out) |
| 1103 | + |
| 1104 | + |
798 | 1105 | class DPNPAngle(DPNPUnaryFunc): |
799 | 1106 | """Class that implements dpnp.angle unary element-wise functions.""" |
800 | 1107 |
|
|
0 commit comments