Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 203 additions & 0 deletions pytensor/link/numba/dispatch/linalg/_LAPACK.py
Original file line number Diff line number Diff line change
Expand Up @@ -1832,3 +1832,206 @@ def hegvd(
)

return hegvd

@classmethod
def numba_xgesvd(cls, dtype) -> CPUDispatcher:
"""
Compute the singular value decomposition of a general M-by-N matrix using the
QR-based algorithm (LAPACK xGESVD).

Called by scipy.linalg.svd with lapack_driver='gesvd' and numpy.linalg.svd for
the non-divide-and-conquer path.
"""
kind = get_blas_kind(dtype)
float_ptr = _get_nb_float_from_dtype(kind)
is_complex = isinstance(dtype, Complex)
real_ptr = nb_f64p if dtype is nb_c128 else nb_f32p
unique_func_name = f"scipy.lapack.{kind}gesvd"

@numba_basic.numba_njit
def get_gesvd_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "gesvd")
return ptr

if is_complex:
gesvd_function_type = types.FunctionType(
types.void(
nb_i32p, # JOBU
nb_i32p, # JOBVT
nb_i32p, # M
nb_i32p, # N
float_ptr, # A
nb_i32p, # LDA
real_ptr, # S
float_ptr, # U
nb_i32p, # LDU
float_ptr, # VT
nb_i32p, # LDVT
float_ptr, # WORK
nb_i32p, # LWORK
real_ptr, # RWORK
nb_i32p, # INFO
)
)

@numba_basic.numba_njit
def gesvd(
JOBU, JOBVT, M, N, A, LDA, S, U, LDU, VT, LDVT, WORK, LWORK, RWORK, INFO
):
fn = _call_cached_ptr(
get_ptr_func=get_gesvd_pointer,
func_type_ref=gesvd_function_type,
unique_func_name_lit=unique_func_name,
)
fn(
JOBU,
JOBVT,
M,
N,
A,
LDA,
S,
U,
LDU,
VT,
LDVT,
WORK,
LWORK,
RWORK,
INFO,
)

else:
gesvd_function_type = types.FunctionType(
types.void(
nb_i32p, # JOBU
nb_i32p, # JOBVT
nb_i32p, # M
nb_i32p, # N
float_ptr, # A
nb_i32p, # LDA
float_ptr, # S
float_ptr, # U
nb_i32p, # LDU
float_ptr, # VT
nb_i32p, # LDVT
float_ptr, # WORK
nb_i32p, # LWORK
nb_i32p, # INFO
)
)

@numba_basic.numba_njit
def gesvd(
JOBU, JOBVT, M, N, A, LDA, S, U, LDU, VT, LDVT, WORK, LWORK, INFO
):
fn = _call_cached_ptr(
get_ptr_func=get_gesvd_pointer,
func_type_ref=gesvd_function_type,
unique_func_name_lit=unique_func_name,
)
fn(JOBU, JOBVT, M, N, A, LDA, S, U, LDU, VT, LDVT, WORK, LWORK, INFO)

return gesvd

@classmethod
def numba_xgesdd(cls, dtype) -> CPUDispatcher:
"""
Compute the singular value decomposition of a general M-by-N matrix using the
divide-and-conquer algorithm (LAPACK xGESDD).

Called by scipy.linalg.svd (default driver) and numpy.linalg.svd.
"""
kind = get_blas_kind(dtype)
float_ptr = _get_nb_float_from_dtype(kind)
is_complex = isinstance(dtype, Complex)
real_ptr = nb_f64p if dtype is nb_c128 else nb_f32p
unique_func_name = f"scipy.lapack.{kind}gesdd"

@numba_basic.numba_njit
def get_gesdd_pointer():
with numba.objmode(ptr=types.intp):
ptr = get_lapack_ptr(dtype, "gesdd")
return ptr

if is_complex:
gesdd_function_type = types.FunctionType(
types.void(
nb_i32p, # JOBZ
nb_i32p, # M
nb_i32p, # N
float_ptr, # A
nb_i32p, # LDA
real_ptr, # S
float_ptr, # U
nb_i32p, # LDU
float_ptr, # VT
nb_i32p, # LDVT
float_ptr, # WORK
nb_i32p, # LWORK
real_ptr, # RWORK
nb_i32p, # IWORK
nb_i32p, # INFO
)
)

@numba_basic.numba_njit
def gesdd(
JOBZ, M, N, A, LDA, S, U, LDU, VT, LDVT, WORK, LWORK, RWORK, IWORK, INFO
):
fn = _call_cached_ptr(
get_ptr_func=get_gesdd_pointer,
func_type_ref=gesdd_function_type,
unique_func_name_lit=unique_func_name,
)
fn(
JOBZ,
M,
N,
A,
LDA,
S,
U,
LDU,
VT,
LDVT,
WORK,
LWORK,
RWORK,
IWORK,
INFO,
)

else:
gesdd_function_type = types.FunctionType(
types.void(
nb_i32p, # JOBZ
nb_i32p, # M
nb_i32p, # N
float_ptr, # A
nb_i32p, # LDA
float_ptr, # S
float_ptr, # U
nb_i32p, # LDU
float_ptr, # VT
nb_i32p, # LDVT
float_ptr, # WORK
nb_i32p, # LWORK
nb_i32p, # IWORK
nb_i32p, # INFO
)
)

@numba_basic.numba_njit
def gesdd(
JOBZ, M, N, A, LDA, S, U, LDU, VT, LDVT, WORK, LWORK, IWORK, INFO
):
fn = _call_cached_ptr(
get_ptr_func=get_gesdd_pointer,
func_type_ref=gesdd_function_type,
unique_func_name_lit=unique_func_name,
)
fn(JOBZ, M, N, A, LDA, S, U, LDU, VT, LDVT, WORK, LWORK, IWORK, INFO)

return gesdd
33 changes: 20 additions & 13 deletions pytensor/link/numba/dispatch/linalg/decomposition/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@
schur_complex,
schur_real,
)
from pytensor.link.numba.dispatch.linalg.decomposition.svd import (
_svd_gesdd_full,
_svd_gesdd_no_uv,
)
from pytensor.tensor.linalg.decomposition.cholesky import Cholesky
from pytensor.tensor.linalg.decomposition.eigen import Eig, Eigh, Eigvalsh
from pytensor.tensor.linalg.decomposition.lu import LU, LUFactor, PivotToPermutations
Expand All @@ -61,14 +65,14 @@ def numba_funcify_SVD(op, node, **kwargs):
if discrete_input and config.compiler_verbose:
print("SVD requires casting discrete input to float") # noqa: T201

# np.linalg.svd always returns real-valued singular values, even for complex input.
# The Op may declare s as complex (matching input dtype), but numba returns the real
# component dtype, so we must match that to avoid type unification errors.
# Casting discrete input to float allocates a new buffer, so in-place is moot.
Copy link
Copy Markdown
Member

@ricardoV94 ricardoV94 May 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hum? if you skip overwrite a in integer inputs you do two copies? the one to float and the one for the output?

you could say that if you cast you can always in place because it's a fresh buffer? opposite of what you did.

(does astype accept order=F). if not we should maybe implement our own version that does.

Actually I've missed this opt in other places?

For another time, should these cast be part of the graph? like expand dims that elemwise adds? then they 1) could be fused or rendered useless or what have you and 2) the numba dispatchers never need to worry about it

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let me check

effective_overwrite_a = op.overwrite_a and not discrete_input

matrix_dtype = out_dtype
if out_dtype.kind == "c":
s_dtype = np.dtype(f"f{out_dtype.itemsize // 2}")
else:
s_dtype = out_dtype
# SVD declares S with the real component dtype via linalg_real_output_dtype,
# so the s output's own dtype is the right answer for both real and complex
# input.
s_dtype = np.dtype(node.outputs[1 if compute_uv else 0].dtype)

if not compute_uv:

Expand All @@ -80,8 +84,7 @@ def svd(x):
return np.zeros((k,), dtype=s_dtype)
if discrete_input:
x = x.astype(out_dtype)
_, ret, _ = np.linalg.svd(x, full_matrices)
return ret
return _svd_gesdd_no_uv(x, overwrite_a=effective_overwrite_a)

else:

Expand All @@ -90,8 +93,8 @@ def svd(x):
if x.size == 0:
m, n = x.shape
k = min(m, n)
# The LAPACK dispatch returns matrices in fortran order. To match this for the empty cases,
# build flip the shape inputs to np.zeros and transpose.
# LAPACK returns matrices in fortran order; build the empty
# returns with reversed shape + transpose to match.
if full_matrices:
return (
np.zeros((m, m), dtype=matrix_dtype).T,
Expand All @@ -106,9 +109,13 @@ def svd(x):
)
if discrete_input:
x = x.astype(out_dtype)
return np.linalg.svd(x, full_matrices)
return _svd_gesdd_full(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

worth a (one time) bench we are doing better than the numba dispatch?

x,
full_matrices=full_matrices,
overwrite_a=effective_overwrite_a,
)

cache_version = 1
cache_version = 2
return svd, cache_version


Expand Down
Loading
Loading