diff --git a/pytensor/link/numba/dispatch/linalg/_LAPACK.py b/pytensor/link/numba/dispatch/linalg/_LAPACK.py index 2e3812cdc4..cc23f6a9b7 100644 --- a/pytensor/link/numba/dispatch/linalg/_LAPACK.py +++ b/pytensor/link/numba/dispatch/linalg/_LAPACK.py @@ -1832,3 +1832,206 @@ def hegvd( ) return hegvd + + @classmethod + def numba_xgesvd(cls, dtype) -> CPUDispatcher: + """ + Compute the singular value decomposition of a general M-by-N matrix using the + QR-based algorithm (LAPACK xGESVD). + + Called by scipy.linalg.svd with lapack_driver='gesvd' and numpy.linalg.svd for + the non-divide-and-conquer path. + """ + kind = get_blas_kind(dtype) + float_ptr = _get_nb_float_from_dtype(kind) + is_complex = isinstance(dtype, Complex) + real_ptr = nb_f64p if dtype is nb_c128 else nb_f32p + unique_func_name = f"scipy.lapack.{kind}gesvd" + + @numba_basic.numba_njit + def get_gesvd_pointer(): + with numba.objmode(ptr=types.intp): + ptr = get_lapack_ptr(dtype, "gesvd") + return ptr + + if is_complex: + gesvd_function_type = types.FunctionType( + types.void( + nb_i32p, # JOBU + nb_i32p, # JOBVT + nb_i32p, # M + nb_i32p, # N + float_ptr, # A + nb_i32p, # LDA + real_ptr, # S + float_ptr, # U + nb_i32p, # LDU + float_ptr, # VT + nb_i32p, # LDVT + float_ptr, # WORK + nb_i32p, # LWORK + real_ptr, # RWORK + nb_i32p, # INFO + ) + ) + + @numba_basic.numba_njit + def gesvd( + JOBU, JOBVT, M, N, A, LDA, S, U, LDU, VT, LDVT, WORK, LWORK, RWORK, INFO + ): + fn = _call_cached_ptr( + get_ptr_func=get_gesvd_pointer, + func_type_ref=gesvd_function_type, + unique_func_name_lit=unique_func_name, + ) + fn( + JOBU, + JOBVT, + M, + N, + A, + LDA, + S, + U, + LDU, + VT, + LDVT, + WORK, + LWORK, + RWORK, + INFO, + ) + + else: + gesvd_function_type = types.FunctionType( + types.void( + nb_i32p, # JOBU + nb_i32p, # JOBVT + nb_i32p, # M + nb_i32p, # N + float_ptr, # A + nb_i32p, # LDA + float_ptr, # S + float_ptr, # U + nb_i32p, # LDU + float_ptr, # VT + nb_i32p, # LDVT + float_ptr, # WORK + nb_i32p, # LWORK + nb_i32p, # INFO + ) + ) + + @numba_basic.numba_njit + def gesvd( + JOBU, JOBVT, M, N, A, LDA, S, U, LDU, VT, LDVT, WORK, LWORK, INFO + ): + fn = _call_cached_ptr( + get_ptr_func=get_gesvd_pointer, + func_type_ref=gesvd_function_type, + unique_func_name_lit=unique_func_name, + ) + fn(JOBU, JOBVT, M, N, A, LDA, S, U, LDU, VT, LDVT, WORK, LWORK, INFO) + + return gesvd + + @classmethod + def numba_xgesdd(cls, dtype) -> CPUDispatcher: + """ + Compute the singular value decomposition of a general M-by-N matrix using the + divide-and-conquer algorithm (LAPACK xGESDD). + + Called by scipy.linalg.svd (default driver) and numpy.linalg.svd. + """ + kind = get_blas_kind(dtype) + float_ptr = _get_nb_float_from_dtype(kind) + is_complex = isinstance(dtype, Complex) + real_ptr = nb_f64p if dtype is nb_c128 else nb_f32p + unique_func_name = f"scipy.lapack.{kind}gesdd" + + @numba_basic.numba_njit + def get_gesdd_pointer(): + with numba.objmode(ptr=types.intp): + ptr = get_lapack_ptr(dtype, "gesdd") + return ptr + + if is_complex: + gesdd_function_type = types.FunctionType( + types.void( + nb_i32p, # JOBZ + nb_i32p, # M + nb_i32p, # N + float_ptr, # A + nb_i32p, # LDA + real_ptr, # S + float_ptr, # U + nb_i32p, # LDU + float_ptr, # VT + nb_i32p, # LDVT + float_ptr, # WORK + nb_i32p, # LWORK + real_ptr, # RWORK + nb_i32p, # IWORK + nb_i32p, # INFO + ) + ) + + @numba_basic.numba_njit + def gesdd( + JOBZ, M, N, A, LDA, S, U, LDU, VT, LDVT, WORK, LWORK, RWORK, IWORK, INFO + ): + fn = _call_cached_ptr( + get_ptr_func=get_gesdd_pointer, + func_type_ref=gesdd_function_type, + unique_func_name_lit=unique_func_name, + ) + fn( + JOBZ, + M, + N, + A, + LDA, + S, + U, + LDU, + VT, + LDVT, + WORK, + LWORK, + RWORK, + IWORK, + INFO, + ) + + else: + gesdd_function_type = types.FunctionType( + types.void( + nb_i32p, # JOBZ + nb_i32p, # M + nb_i32p, # N + float_ptr, # A + nb_i32p, # LDA + float_ptr, # S + float_ptr, # U + nb_i32p, # LDU + float_ptr, # VT + nb_i32p, # LDVT + float_ptr, # WORK + nb_i32p, # LWORK + nb_i32p, # IWORK + nb_i32p, # INFO + ) + ) + + @numba_basic.numba_njit + def gesdd( + JOBZ, M, N, A, LDA, S, U, LDU, VT, LDVT, WORK, LWORK, IWORK, INFO + ): + fn = _call_cached_ptr( + get_ptr_func=get_gesdd_pointer, + func_type_ref=gesdd_function_type, + unique_func_name_lit=unique_func_name, + ) + fn(JOBZ, M, N, A, LDA, S, U, LDU, VT, LDVT, WORK, LWORK, IWORK, INFO) + + return gesdd diff --git a/pytensor/link/numba/dispatch/linalg/decomposition/dispatch.py b/pytensor/link/numba/dispatch/linalg/decomposition/dispatch.py index 99c0fca284..239c194fad 100644 --- a/pytensor/link/numba/dispatch/linalg/decomposition/dispatch.py +++ b/pytensor/link/numba/dispatch/linalg/decomposition/dispatch.py @@ -43,6 +43,10 @@ schur_complex, schur_real, ) +from pytensor.link.numba.dispatch.linalg.decomposition.svd import ( + _svd_gesdd_full, + _svd_gesdd_no_uv, +) from pytensor.tensor.linalg.decomposition.cholesky import Cholesky from pytensor.tensor.linalg.decomposition.eigen import Eig, Eigh, Eigvalsh from pytensor.tensor.linalg.decomposition.lu import LU, LUFactor, PivotToPermutations @@ -61,14 +65,14 @@ def numba_funcify_SVD(op, node, **kwargs): if discrete_input and config.compiler_verbose: print("SVD requires casting discrete input to float") # noqa: T201 - # np.linalg.svd always returns real-valued singular values, even for complex input. - # The Op may declare s as complex (matching input dtype), but numba returns the real - # component dtype, so we must match that to avoid type unification errors. + # Casting discrete input to float allocates a new buffer, so in-place is moot. + effective_overwrite_a = op.overwrite_a and not discrete_input + matrix_dtype = out_dtype - if out_dtype.kind == "c": - s_dtype = np.dtype(f"f{out_dtype.itemsize // 2}") - else: - s_dtype = out_dtype + # SVD declares S with the real component dtype via linalg_real_output_dtype, + # so the s output's own dtype is the right answer for both real and complex + # input. + s_dtype = np.dtype(node.outputs[1 if compute_uv else 0].dtype) if not compute_uv: @@ -80,8 +84,7 @@ def svd(x): return np.zeros((k,), dtype=s_dtype) if discrete_input: x = x.astype(out_dtype) - _, ret, _ = np.linalg.svd(x, full_matrices) - return ret + return _svd_gesdd_no_uv(x, overwrite_a=effective_overwrite_a) else: @@ -90,8 +93,8 @@ def svd(x): if x.size == 0: m, n = x.shape k = min(m, n) - # The LAPACK dispatch returns matrices in fortran order. To match this for the empty cases, - # build flip the shape inputs to np.zeros and transpose. + # LAPACK returns matrices in fortran order; build the empty + # returns with reversed shape + transpose to match. if full_matrices: return ( np.zeros((m, m), dtype=matrix_dtype).T, @@ -106,9 +109,13 @@ def svd(x): ) if discrete_input: x = x.astype(out_dtype) - return np.linalg.svd(x, full_matrices) + return _svd_gesdd_full( + x, + full_matrices=full_matrices, + overwrite_a=effective_overwrite_a, + ) - cache_version = 1 + cache_version = 2 return svd, cache_version diff --git a/pytensor/link/numba/dispatch/linalg/decomposition/svd.py b/pytensor/link/numba/dispatch/linalg/decomposition/svd.py new file mode 100644 index 0000000000..d9828c3632 --- /dev/null +++ b/pytensor/link/numba/dispatch/linalg/decomposition/svd.py @@ -0,0 +1,719 @@ +import numpy as np +from numba.core.extending import overload +from numba.core.types import Complex, Float +from numba.np.linalg import _copy_to_fortran_order, ensure_lapack + +from pytensor.link.numba.dispatch.linalg._LAPACK import ( + _LAPACK, + _get_underlying_float, + int_ptr_to_val, + val_to_int_ptr, +) +from pytensor.link.numba.dispatch.linalg.utils import _check_linalg_matrix + + +def _svd_gesvd_full(a, full_matrices=True): + """Placeholder; overloaded with a direct xGESVD dispatch.""" + return np.linalg.svd(a, full_matrices=full_matrices) + + +def _svd_gesvd_no_uv(a): + """Placeholder; overloaded with xGESVD computing singular values only.""" + return np.linalg.svd(a, compute_uv=False) + + +def _svd_gesdd_full(a, full_matrices=True, overwrite_a=False): + """Placeholder; overloaded with a direct xGESDD dispatch.""" + return np.linalg.svd(a, full_matrices=full_matrices) + + +def _svd_gesdd_no_uv(a, overwrite_a=False): + """Placeholder; overloaded with xGESDD computing singular values only.""" + return np.linalg.svd(a, compute_uv=False) + + +@overload(_svd_gesvd_full) +def svd_gesvd_full_impl(A, full_matrices=True): + ensure_lapack() + _check_linalg_matrix(A, ndim=2, dtype=(Float, Complex), func_name="svd") + dtype = A.dtype + is_complex = isinstance(dtype, Complex) + real_dtype = _get_underlying_float(dtype) if is_complex else None + numba_gesvd = _LAPACK().numba_xgesvd(dtype) + + if is_complex: + + def impl(A, full_matrices=True): + _M = np.int32(A.shape[0]) + _N = np.int32(A.shape[1]) + _K = min(_M, _N) + + A_copy = _copy_to_fortran_order(A) + + if full_matrices: + JOBU = val_to_int_ptr(ord("A")) + JOBVT = val_to_int_ptr(ord("A")) + U = np.empty((_M, _M), dtype=A.dtype).T + VT = np.empty((_N, _N), dtype=A.dtype).T + else: + JOBU = val_to_int_ptr(ord("S")) + JOBVT = val_to_int_ptr(ord("S")) + U = np.empty((_K, _M), dtype=A.dtype).T + VT = np.empty((_N, _K), dtype=A.dtype).T + + S = np.empty(_K, dtype=real_dtype) + M = val_to_int_ptr(_M) + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(max(np.int32(1), _M)) + LDU = val_to_int_ptr(max(np.int32(1), _M)) + LDVT = val_to_int_ptr(max(np.int32(1), np.int32(VT.shape[0]))) + + RWORK = np.empty(max(np.int32(1), 5 * _K), dtype=real_dtype) + INFO = val_to_int_ptr(0) + + # Workspace query + LWORK = val_to_int_ptr(np.int32(-1)) + WORK = np.empty(1, dtype=A.dtype) + numba_gesvd( + JOBU, + JOBVT, + M, + N, + A_copy.ctypes, + LDA, + S.ctypes, + U.ctypes, + LDU, + VT.ctypes, + LDVT, + WORK.ctypes, + LWORK, + RWORK.ctypes, + INFO, + ) + + lwork = np.int32(WORK[0].real) + WORK = np.empty(lwork, dtype=A.dtype) + LWORK = val_to_int_ptr(lwork) + INFO = val_to_int_ptr(0) + + numba_gesvd( + JOBU, + JOBVT, + M, + N, + A_copy.ctypes, + LDA, + S.ctypes, + U.ctypes, + LDU, + VT.ctypes, + LDVT, + WORK.ctypes, + LWORK, + RWORK.ctypes, + INFO, + ) + + if int_ptr_to_val(INFO) != 0: + S[:] = np.nan + U[:] = np.nan + VT[:] = np.nan + + return U, S, VT + + else: + + def impl(A, full_matrices=True): + _M = np.int32(A.shape[0]) + _N = np.int32(A.shape[1]) + _K = min(_M, _N) + + A_copy = _copy_to_fortran_order(A) + + if full_matrices: + JOBU = val_to_int_ptr(ord("A")) + JOBVT = val_to_int_ptr(ord("A")) + U = np.empty((_M, _M), dtype=A.dtype).T + VT = np.empty((_N, _N), dtype=A.dtype).T + else: + JOBU = val_to_int_ptr(ord("S")) + JOBVT = val_to_int_ptr(ord("S")) + U = np.empty((_K, _M), dtype=A.dtype).T + VT = np.empty((_N, _K), dtype=A.dtype).T + + S = np.empty(_K, dtype=A.dtype) + M = val_to_int_ptr(_M) + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(max(np.int32(1), _M)) + LDU = val_to_int_ptr(max(np.int32(1), _M)) + LDVT = val_to_int_ptr(max(np.int32(1), np.int32(VT.shape[0]))) + INFO = val_to_int_ptr(0) + + LWORK = val_to_int_ptr(np.int32(-1)) + WORK = np.empty(1, dtype=A.dtype) + numba_gesvd( + JOBU, + JOBVT, + M, + N, + A_copy.ctypes, + LDA, + S.ctypes, + U.ctypes, + LDU, + VT.ctypes, + LDVT, + WORK.ctypes, + LWORK, + INFO, + ) + + lwork = np.int32(WORK[0]) + WORK = np.empty(lwork, dtype=A.dtype) + LWORK = val_to_int_ptr(lwork) + INFO = val_to_int_ptr(0) + + numba_gesvd( + JOBU, + JOBVT, + M, + N, + A_copy.ctypes, + LDA, + S.ctypes, + U.ctypes, + LDU, + VT.ctypes, + LDVT, + WORK.ctypes, + LWORK, + INFO, + ) + + if int_ptr_to_val(INFO) != 0: + S[:] = np.nan + U[:] = np.nan + VT[:] = np.nan + + return U, S, VT + + return impl + + +@overload(_svd_gesvd_no_uv) +def svd_gesvd_no_uv_impl(A): + ensure_lapack() + _check_linalg_matrix(A, ndim=2, dtype=(Float, Complex), func_name="svd") + dtype = A.dtype + is_complex = isinstance(dtype, Complex) + real_dtype = _get_underlying_float(dtype) if is_complex else None + numba_gesvd = _LAPACK().numba_xgesvd(dtype) + + if is_complex: + + def impl(A): + _M = np.int32(A.shape[0]) + _N = np.int32(A.shape[1]) + _K = min(_M, _N) + + A_copy = _copy_to_fortran_order(A) + + JOBU = val_to_int_ptr(ord("N")) + JOBVT = val_to_int_ptr(ord("N")) + # JOBU='N' / JOBVT='N': U and VT are not referenced but still need + # valid pointers and LDU/LDVT >= 1. + U = np.empty(1, dtype=A.dtype) + VT = np.empty(1, dtype=A.dtype) + S = np.empty(_K, dtype=real_dtype) + M = val_to_int_ptr(_M) + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(max(np.int32(1), _M)) + LDU = val_to_int_ptr(np.int32(1)) + LDVT = val_to_int_ptr(np.int32(1)) + RWORK = np.empty(max(np.int32(1), 5 * _K), dtype=real_dtype) + INFO = val_to_int_ptr(0) + + LWORK = val_to_int_ptr(np.int32(-1)) + WORK = np.empty(1, dtype=A.dtype) + numba_gesvd( + JOBU, + JOBVT, + M, + N, + A_copy.ctypes, + LDA, + S.ctypes, + U.ctypes, + LDU, + VT.ctypes, + LDVT, + WORK.ctypes, + LWORK, + RWORK.ctypes, + INFO, + ) + + lwork = np.int32(WORK[0].real) + WORK = np.empty(lwork, dtype=A.dtype) + LWORK = val_to_int_ptr(lwork) + INFO = val_to_int_ptr(0) + + numba_gesvd( + JOBU, + JOBVT, + M, + N, + A_copy.ctypes, + LDA, + S.ctypes, + U.ctypes, + LDU, + VT.ctypes, + LDVT, + WORK.ctypes, + LWORK, + RWORK.ctypes, + INFO, + ) + + if int_ptr_to_val(INFO) != 0: + S[:] = np.nan + + return S + + else: + + def impl(A): + _M = np.int32(A.shape[0]) + _N = np.int32(A.shape[1]) + _K = min(_M, _N) + + A_copy = _copy_to_fortran_order(A) + + JOBU = val_to_int_ptr(ord("N")) + JOBVT = val_to_int_ptr(ord("N")) + U = np.empty(1, dtype=A.dtype) + VT = np.empty(1, dtype=A.dtype) + S = np.empty(_K, dtype=A.dtype) + M = val_to_int_ptr(_M) + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(max(np.int32(1), _M)) + LDU = val_to_int_ptr(np.int32(1)) + LDVT = val_to_int_ptr(np.int32(1)) + INFO = val_to_int_ptr(0) + + LWORK = val_to_int_ptr(np.int32(-1)) + WORK = np.empty(1, dtype=A.dtype) + numba_gesvd( + JOBU, + JOBVT, + M, + N, + A_copy.ctypes, + LDA, + S.ctypes, + U.ctypes, + LDU, + VT.ctypes, + LDVT, + WORK.ctypes, + LWORK, + INFO, + ) + + lwork = np.int32(WORK[0]) + WORK = np.empty(lwork, dtype=A.dtype) + LWORK = val_to_int_ptr(lwork) + INFO = val_to_int_ptr(0) + + numba_gesvd( + JOBU, + JOBVT, + M, + N, + A_copy.ctypes, + LDA, + S.ctypes, + U.ctypes, + LDU, + VT.ctypes, + LDVT, + WORK.ctypes, + LWORK, + INFO, + ) + + if int_ptr_to_val(INFO) != 0: + S[:] = np.nan + + return S + + return impl + + +@overload(_svd_gesdd_full) +def svd_gesdd_full_impl(A, full_matrices=True, overwrite_a=False): + ensure_lapack() + _check_linalg_matrix(A, ndim=2, dtype=(Float, Complex), func_name="svd") + dtype = A.dtype + is_complex = isinstance(dtype, Complex) + real_dtype = _get_underlying_float(dtype) if is_complex else None + numba_gesdd = _LAPACK().numba_xgesdd(dtype) + + if is_complex: + + def impl(A, full_matrices=True, overwrite_a=False): + _M = np.int32(A.shape[0]) + _N = np.int32(A.shape[1]) + _K = min(_M, _N) + _MAX = max(_M, _N) + + # gesdd uses A as scratch and clobbers it regardless of JOBZ, so + # when the caller donates an f-contig A we reuse the buffer; the + # post-call contents are meaningless but the alloc is saved. + if overwrite_a and A.flags.f_contiguous: + A_copy = A + else: + A_copy = _copy_to_fortran_order(A) + + if full_matrices: + JOBZ = val_to_int_ptr(ord("A")) + U = np.empty((_M, _M), dtype=A.dtype).T + VT = np.empty((_N, _N), dtype=A.dtype).T + else: + JOBZ = val_to_int_ptr(ord("S")) + U = np.empty((_K, _M), dtype=A.dtype).T + VT = np.empty((_N, _K), dtype=A.dtype).T + + S = np.empty(_K, dtype=real_dtype) + M = val_to_int_ptr(_M) + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(max(np.int32(1), _M)) + LDU = val_to_int_ptr(max(np.int32(1), _M)) + LDVT = val_to_int_ptr(max(np.int32(1), np.int32(VT.shape[0]))) + + # gesdd RWORK sizing (complex, JOBZ != 'N'): LAPACK doc minimum is + # max(1, mn * max(5*mn + 7, 2*mx + 2*mn + 1)) + # gesdd has no LRWORK argument and the WORK query does not return + # an RWORK size, so the formula is the only way to size RWORK. + lrwork = np.int32( + max( + np.int32(1), + max( + 5 * _K * _K + 7 * _K, + 2 * _MAX * _K + 2 * _K * _K + _K, + ), + ) + ) + RWORK = np.empty(lrwork, dtype=real_dtype) + IWORK = np.empty(max(np.int32(1), 8 * _K), dtype=np.int32) + INFO = val_to_int_ptr(0) + + LWORK = val_to_int_ptr(np.int32(-1)) + WORK = np.empty(1, dtype=A.dtype) + numba_gesdd( + JOBZ, + M, + N, + A_copy.ctypes, + LDA, + S.ctypes, + U.ctypes, + LDU, + VT.ctypes, + LDVT, + WORK.ctypes, + LWORK, + RWORK.ctypes, + IWORK.ctypes, + INFO, + ) + + lwork = np.int32(WORK[0].real) + WORK = np.empty(lwork, dtype=A.dtype) + LWORK = val_to_int_ptr(lwork) + INFO = val_to_int_ptr(0) + + numba_gesdd( + JOBZ, + M, + N, + A_copy.ctypes, + LDA, + S.ctypes, + U.ctypes, + LDU, + VT.ctypes, + LDVT, + WORK.ctypes, + LWORK, + RWORK.ctypes, + IWORK.ctypes, + INFO, + ) + + if int_ptr_to_val(INFO) != 0: + S[:] = np.nan + U[:] = np.nan + VT[:] = np.nan + + return U, S, VT + + else: + + def impl(A, full_matrices=True, overwrite_a=False): + _M = np.int32(A.shape[0]) + _N = np.int32(A.shape[1]) + _K = min(_M, _N) + + # Real A: a c-contiguous buffer reinterpreted as fortran-order is + # A.T (M and N swapped in LAPACK's view). For SVD this lets us + # solve the swapped (N, M) problem and recover A's SVD by swapping + # roles: if A = U S Vt, then A.T = V S U.T, so LAPACK's U' maps to + # Vt.T and LAPACK's Vt' maps to U.T. + swap = False + if overwrite_a and A.flags.f_contiguous: + A_copy = A + elif overwrite_a and A.flags.c_contiguous: + A_copy = A.T + swap = True + _M, _N = _N, _M + else: + A_copy = _copy_to_fortran_order(A) + + if full_matrices: + JOBZ = val_to_int_ptr(ord("A")) + U = np.empty((_M, _M), dtype=A.dtype).T + VT = np.empty((_N, _N), dtype=A.dtype).T + else: + JOBZ = val_to_int_ptr(ord("S")) + U = np.empty((_K, _M), dtype=A.dtype).T + VT = np.empty((_N, _K), dtype=A.dtype).T + + S = np.empty(_K, dtype=A.dtype) + M = val_to_int_ptr(_M) + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(max(np.int32(1), _M)) + LDU = val_to_int_ptr(max(np.int32(1), _M)) + LDVT = val_to_int_ptr(max(np.int32(1), np.int32(VT.shape[0]))) + IWORK = np.empty(max(np.int32(1), 8 * _K), dtype=np.int32) + INFO = val_to_int_ptr(0) + + LWORK = val_to_int_ptr(np.int32(-1)) + WORK = np.empty(1, dtype=A.dtype) + numba_gesdd( + JOBZ, + M, + N, + A_copy.ctypes, + LDA, + S.ctypes, + U.ctypes, + LDU, + VT.ctypes, + LDVT, + WORK.ctypes, + LWORK, + IWORK.ctypes, + INFO, + ) + + lwork = np.int32(WORK[0]) + WORK = np.empty(lwork, dtype=A.dtype) + LWORK = val_to_int_ptr(lwork) + INFO = val_to_int_ptr(0) + + numba_gesdd( + JOBZ, + M, + N, + A_copy.ctypes, + LDA, + S.ctypes, + U.ctypes, + LDU, + VT.ctypes, + LDVT, + WORK.ctypes, + LWORK, + IWORK.ctypes, + INFO, + ) + + if int_ptr_to_val(INFO) != 0: + S[:] = np.nan + U[:] = np.nan + VT[:] = np.nan + + if swap: + # Solved SVD of A.T; map back: A's U = (LAPACK's VT).T, + # A's VT = (LAPACK's U).T. The .T's are zero-cost stride swaps. + return VT.T, S, U.T + return U, S, VT + + return impl + + +@overload(_svd_gesdd_no_uv) +def svd_gesdd_no_uv_impl(A, overwrite_a=False): + ensure_lapack() + _check_linalg_matrix(A, ndim=2, dtype=(Float, Complex), func_name="svd") + dtype = A.dtype + is_complex = isinstance(dtype, Complex) + real_dtype = _get_underlying_float(dtype) if is_complex else None + numba_gesdd = _LAPACK().numba_xgesdd(dtype) + + if is_complex: + + def impl(A, overwrite_a=False): + _M = np.int32(A.shape[0]) + _N = np.int32(A.shape[1]) + _K = min(_M, _N) + + if overwrite_a and A.flags.f_contiguous: + A_copy = A + else: + A_copy = _copy_to_fortran_order(A) + + JOBZ = val_to_int_ptr(ord("N")) + U = np.empty(1, dtype=A.dtype) + VT = np.empty(1, dtype=A.dtype) + S = np.empty(_K, dtype=real_dtype) + M = val_to_int_ptr(_M) + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(max(np.int32(1), _M)) + LDU = val_to_int_ptr(np.int32(1)) + LDVT = val_to_int_ptr(np.int32(1)) + + # gesdd RWORK sizing for JOBZ='N'. + lrwork = max(np.int32(1), 7 * _K) + RWORK = np.empty(lrwork, dtype=real_dtype) + IWORK = np.empty(max(np.int32(1), 8 * _K), dtype=np.int32) + INFO = val_to_int_ptr(0) + + LWORK = val_to_int_ptr(np.int32(-1)) + WORK = np.empty(1, dtype=A.dtype) + numba_gesdd( + JOBZ, + M, + N, + A_copy.ctypes, + LDA, + S.ctypes, + U.ctypes, + LDU, + VT.ctypes, + LDVT, + WORK.ctypes, + LWORK, + RWORK.ctypes, + IWORK.ctypes, + INFO, + ) + + lwork = np.int32(WORK[0].real) + WORK = np.empty(lwork, dtype=A.dtype) + LWORK = val_to_int_ptr(lwork) + INFO = val_to_int_ptr(0) + + numba_gesdd( + JOBZ, + M, + N, + A_copy.ctypes, + LDA, + S.ctypes, + U.ctypes, + LDU, + VT.ctypes, + LDVT, + WORK.ctypes, + LWORK, + RWORK.ctypes, + IWORK.ctypes, + INFO, + ) + + if int_ptr_to_val(INFO) != 0: + S[:] = np.nan + + return S + + else: + + def impl(A, overwrite_a=False): + _M = np.int32(A.shape[0]) + _N = np.int32(A.shape[1]) + _K = min(_M, _N) + + # Real, JOBZ='N': singular values of A.T equal those of A, so the + # c-contig-as-f-contig reinterpretation needs no fix-up beyond + # swapping M and N in the LAPACK call. + if overwrite_a and A.flags.f_contiguous: + A_copy = A + elif overwrite_a and A.flags.c_contiguous: + A_copy = A.T + _M, _N = _N, _M + else: + A_copy = _copy_to_fortran_order(A) + + JOBZ = val_to_int_ptr(ord("N")) + U = np.empty(1, dtype=A.dtype) + VT = np.empty(1, dtype=A.dtype) + S = np.empty(_K, dtype=A.dtype) + M = val_to_int_ptr(_M) + N = val_to_int_ptr(_N) + LDA = val_to_int_ptr(max(np.int32(1), _M)) + LDU = val_to_int_ptr(np.int32(1)) + LDVT = val_to_int_ptr(np.int32(1)) + IWORK = np.empty(max(np.int32(1), 8 * _K), dtype=np.int32) + INFO = val_to_int_ptr(0) + + LWORK = val_to_int_ptr(np.int32(-1)) + WORK = np.empty(1, dtype=A.dtype) + numba_gesdd( + JOBZ, + M, + N, + A_copy.ctypes, + LDA, + S.ctypes, + U.ctypes, + LDU, + VT.ctypes, + LDVT, + WORK.ctypes, + LWORK, + IWORK.ctypes, + INFO, + ) + + lwork = np.int32(WORK[0]) + WORK = np.empty(lwork, dtype=A.dtype) + LWORK = val_to_int_ptr(lwork) + INFO = val_to_int_ptr(0) + + numba_gesdd( + JOBZ, + M, + N, + A_copy.ctypes, + LDA, + S.ctypes, + U.ctypes, + LDU, + VT.ctypes, + LDVT, + WORK.ctypes, + LWORK, + IWORK.ctypes, + INFO, + ) + + if int_ptr_to_val(INFO) != 0: + S[:] = np.nan + + return S + + return impl diff --git a/pytensor/tensor/linalg/decomposition/svd.py b/pytensor/tensor/linalg/decomposition/svd.py index 86291b5104..584eae1419 100644 --- a/pytensor/tensor/linalg/decomposition/svd.py +++ b/pytensor/tensor/linalg/decomposition/svd.py @@ -30,15 +30,25 @@ class SVD(Op): compute_uv : bool, optional Whether or not to compute u and v in addition to s. True by default. + overwrite_a : bool, optional + Permit the input matrix to be destroyed during the computation, + saving an allocation. The rewriter promotes this when the input + buffer is otherwise unused. Default False. """ # See doc in the docstring of the function just after this class. - __props__ = ("full_matrices", "compute_uv") + __props__ = ("full_matrices", "compute_uv", "overwrite_a") - def __init__(self, full_matrices: bool = True, compute_uv: bool = True): + def __init__( + self, + full_matrices: bool = True, + compute_uv: bool = True, + overwrite_a: bool = False, + ): self.full_matrices = bool(full_matrices) self.compute_uv = bool(compute_uv) + self.overwrite_a = bool(overwrite_a) if self.compute_uv: if self.full_matrices: self.gufunc_signature = "(m,n)->(m,m),(k),(n,n)" @@ -46,6 +56,15 @@ def __init__(self, full_matrices: bool = True, compute_uv: bool = True): self.gufunc_signature = "(m,n)->(m,k),(k),(k,n)" else: self.gufunc_signature = "(m,n)->(k)" + if self.overwrite_a: + self.destroy_map = {0: [0]} + + def inplace_on_inputs(self, allowed_inplace_inputs: list[int]) -> "Op": + if not allowed_inplace_inputs: + return self + new_props = self._props_dict() # type: ignore + new_props["overwrite_a"] = True + return type(self)(**new_props) def make_node(self, x): x = as_tensor_variable(x) diff --git a/tests/link/numba/linalg/test_decomposition.py b/tests/link/numba/linalg/test_decomposition.py index deb84576df..96943cb290 100644 --- a/tests/link/numba/linalg/test_decomposition.py +++ b/tests/link/numba/linalg/test_decomposition.py @@ -139,6 +139,80 @@ def test_SVD(x, full_matrices, compute_uv): compare_numba_and_py([x], g, [test_x]) +@pytest.mark.parametrize( + "full_matrices, compute_uv", + [(True, True), (False, True), (True, False)], + ids=["full_uv", "econ_uv", "no_uv"], +) +@pytest.mark.parametrize( + "overwrite_a", [False, True], ids=["no_overwrite", "overwrite_a"] +) +@pytest.mark.parametrize("is_complex", [False, True], ids=["real", "complex"]) +def test_SVD_inplace(full_matrices, compute_uv, overwrite_a, is_complex): + complex_dtype = "complex64" if floatX.endswith("32") else "complex128" + dtype = complex_dtype if is_complex else floatX + + x = pt.matrix("x", dtype=dtype) + op = svd.SVD( + full_matrices=full_matrices, + compute_uv=compute_uv, + overwrite_a=overwrite_a, + ) + outs = op(x) + out_list = list(outs) if compute_uv else [outs] + + assert op.destroy_map == ({0: [0]} if overwrite_a else {}) + + fn = pytensor.function( + [In(x, mutable=overwrite_a)], + out_list, + mode=numba_inplace_mode, + accept_inplace=True, + ) + + fn_op = fn.maker.fgraph.outputs[0].owner.op + core_op = fn_op.core_op if isinstance(fn_op, pt.blockwise.Blockwise) else fn_op + assert isinstance(core_op, svd.SVD) + assert core_op.destroy_map == ({0: [0]} if overwrite_a else {}) + + local_rng = np.random.default_rng(0) + val = local_rng.normal(size=(4, 4)).astype(floatX) + if is_complex: + val = (val + 1j * local_rng.normal(size=(4, 4)).astype(floatX)).astype(dtype) + + ref_s = np.linalg.svd(val, full_matrices=full_matrices, compute_uv=False) + + # gesdd reuses the buffer when the donated array is f-contig (always) or + # c-contig (real-typed only — complex would need a conjugation pass). + can_reuse_c = overwrite_a and not is_complex + + def check(layout_val, expect_mutation): + snapshot = np.array(layout_val) + out = fn(layout_val) + s = out[1] if compute_uv else out[0] + np.testing.assert_allclose(s, ref_s, atol=1e-4, rtol=1e-4) + if compute_uv: + U, _, Vt = out + if full_matrices: + # Pad s into the rectangular shape so we can reconstruct. + k = s.shape[0] + m, n = val.shape + S_mat = np.zeros((m, n), dtype=U.dtype) + S_mat[:k, :k] = np.diag(s.astype(U.dtype)) + recon = U @ S_mat @ Vt + else: + recon = (U * s.astype(U.dtype)) @ Vt + np.testing.assert_allclose(recon, val, atol=5e-4, rtol=5e-4) + if expect_mutation: + assert not np.allclose(layout_val, snapshot) + else: + np.testing.assert_allclose(layout_val, snapshot) + + check(np.copy(val, order="F"), expect_mutation=overwrite_a) + check(np.copy(val, order="C"), expect_mutation=can_reuse_c) + check(np.repeat(val, 2, axis=0)[::2], expect_mutation=False) + + class TestDecompositions: @pytest.mark.parametrize("lower", [True, False], ids=lambda x: f"lower={x}") @pytest.mark.parametrize(