From 534b4fb0abf1b121bc3af06cedf3a6cd30df8f5e Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 12 Apr 2026 17:42:08 -0500 Subject: [PATCH 01/13] Refactor blas.py into a module, split out classes by Op --- pytensor/tensor/blas.py | 1746 ------------------------------ pytensor/tensor/blas/__init__.py | 145 +++ pytensor/tensor/blas/_core.py | 190 ++++ pytensor/tensor/blas/batched.py | 460 ++++++++ pytensor/tensor/blas/gemm.py | 860 +++++++++++++++ pytensor/tensor/blas/gemv.py | 117 ++ pytensor/tensor/blas/ger.py | 77 ++ 7 files changed, 1849 insertions(+), 1746 deletions(-) delete mode 100644 pytensor/tensor/blas.py create mode 100644 pytensor/tensor/blas/__init__.py create mode 100644 pytensor/tensor/blas/_core.py create mode 100644 pytensor/tensor/blas/batched.py create mode 100644 pytensor/tensor/blas/gemm.py create mode 100644 pytensor/tensor/blas/gemv.py create mode 100644 pytensor/tensor/blas/ger.py diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py deleted file mode 100644 index 247034451b..0000000000 --- a/pytensor/tensor/blas.py +++ /dev/null @@ -1,1746 +0,0 @@ -"""Ops for using BLAS calls - -BLAS = Basic Linear Algebra Subroutines -Learn more about BLAS here: - http://www.netlib.org/blas/blast-forum/ -The standard BLAS libraries implement what is called "legacy BLAS" in that -document. - -This documentation describes PyTensor's BLAS optimization pipeline. - -Where there is a discrepancy between how things do work and how they *should* -work, both aspects should be documented. - -There are four kinds of BLAS Ops in PyTensor: - - Python implementations (this file) - - SciPy-based (blas_scipy) - - C-based (blas_c) - -Notes ------ -Unfortunately (because it's confusing) this file currently contains Ops -that contain both Python and C versions. I think it would be better to -move the C implementations to blas_c so that this file is pure Python. --JB - - -Ops -=== - -GEMM: Dot22, Dot22Scalar, GemmRelated, Gemm -------------------------------------------- - -The BLAS GEMM operation implements Z <- a X Y + b Z, -where Z, X and Y are matrices, and a and b are scalars. - -Dot22 is a GEMM where a=1, b=0, and Z is allocated every time. - -Dot22Scalar is a GEMM where b=0 and Z is allocated every time. - -Gemm is a GEMM in all its generality. - -In the future we can refactor the GemmRelated, Gemm, Dot22 and -Dot22Scalar Ops into a single Op. That new Op (Gemm2) is basically a -normal Gemm, but with an additional configuration variable that says -to ignore the input Z. Setting that configuration variable to True -would make Gemm2 equivalent to the current Dot22 and Dot22Scalar. -This would make the file a lot easier to read, and save a few hundred -lines of library, to say nothing of testing and documentation. - - -GEMV: Gemv ----------- - -The BLAS GEMV operation implements Z <- a X Y + b Z, -where X is a matrix, Y, and Z are vectors, and a and b are scalars. - - -GER: Ger --------- - -The BLAS GER operation implements Z <- a X' Y + Z, -where X and Y are vectors, and matrix Z gets a rank-1 update. - - -Other Notable BLAS-related Ops ------------------------------- - -SYRK is another useful special case of GEMM. Particularly SYRK preserves -symmetry in the matrix that it updates. See how the linear-algebra module uses -symmetry hints before implementing this Op, so that this Op is compatible with -that system. - - -Optimizations associated with these BLAS Ops are in tensor.rewriting.blas - -""" - -import functools -import logging -import shlex -import warnings -from pathlib import Path - -import numpy as np -from numpy.lib.array_utils import normalize_axis_tuple -from scipy.linalg import get_blas_funcs - -from pytensor.graph import Variable, vectorize_graph - - -try: - import numpy.__config__ -except ImportError: - pass - - -import pytensor.scalar -from pytensor.configdefaults import config -from pytensor.gradient import DisconnectedType, disconnected_type -from pytensor.graph.basic import Apply -from pytensor.graph.op import Op -from pytensor.graph.utils import InconsistencyError, MethodNotDefined -from pytensor.link.c.op import COp -from pytensor.link.c.params_type import ParamsType -from pytensor.printing import FunctionPrinter, pprint -from pytensor.scalar import bool as bool_t -from pytensor.tensor.basic import as_tensor_variable, cast -from pytensor.tensor.blas_headers import blas_header_text, blas_header_version -from pytensor.tensor.math import dot, tensordot -from pytensor.tensor.shape import specify_broadcastable -from pytensor.tensor.type import DenseTensorType, tensor - - -_logger = logging.getLogger("pytensor.tensor.blas") - - -def view_roots(node: Variable) -> list[Variable]: - """Return the leaves from a search through consecutive view-maps.""" - owner = node.owner - if owner is not None: - try: - vars_to_views = {owner.outputs[o]: i for o, i in owner.op.view_map.items()} - except AttributeError: - return [node] - if node in vars_to_views: - answer = [] - for i in vars_to_views[node]: - answer += view_roots(owner.inputs[i]) - return answer - else: - return [node] - else: - return [node] - - -def must_initialize_y_gemv(): - # Check whether Scipy GEMV could output nan if y in not initialized - from scipy.linalg.blas import get_blas_funcs - - if must_initialize_y_gemv._result is None: - y = np.full((2,), np.nan) - x = np.ones((2,)) - A = np.ones((2, 2)) - gemv = get_blas_funcs("gemv", dtype=y.dtype) - gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True) - must_initialize_y_gemv._result = np.isnan(y).any() - - return must_initialize_y_gemv._result - - -must_initialize_y_gemv._result = None # type: ignore - - -class Gemv(Op): - """ - expression is beta * y + alpha * A x - - A is matrix - x, y are vectors - alpha, beta are scalars - output is a vector that can be inplace on y - - """ - - __props__ = ("inplace",) - - def __init__(self, inplace): - self.inplace = inplace - if inplace: - self.destroy_map = {0: [0]} - - def __str__(self): - if self.inplace: - return f"{self.__class__.__name__}{{inplace}}" - else: - return f"{self.__class__.__name__}{{no_inplace}}" - - def make_node(self, y, alpha, A, x, beta): - y = as_tensor_variable(y) - x = as_tensor_variable(x) - A = as_tensor_variable(A) - alpha = as_tensor_variable(alpha) - beta = as_tensor_variable(beta) - if y.dtype != A.dtype or y.dtype != x.dtype: - raise TypeError( - "Gemv requires matching dtypes", (y.dtype, A.dtype, x.dtype) - ) - if A.ndim != 2: - raise TypeError("gemv requires matrix for A", A.type) - if x.ndim != 1: - raise TypeError("gemv requires vector for x", x.type) - if y.ndim != 1: - raise TypeError("gemv requires vector for y", y.type) - - inputs = [y, alpha, A, x, beta] - - if any(not isinstance(i.type, DenseTensorType) for i in inputs): - raise NotImplementedError("Only dense tensor types are supported") - - return Apply(self, inputs, [y.type()]) - - def perform(self, node, inputs, out_storage): - from scipy.linalg.blas import get_blas_funcs - - y, alpha, A, x, beta = inputs - if ( - y.shape[0] != 0 - and x.shape[0] != 0 - and y.dtype in {"float32", "float64", "complex64", "complex128"} - ): - gemv = get_blas_funcs("gemv", dtype=y.dtype) - - if A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]: - raise ValueError( - "Incompatible shapes for gemv " - f"(beta * y + alpha * dot(A, x)). y: {y.shape}, A: {A.shape}, x: {x.shape}" - ) - - if beta == 0 and must_initialize_y_gemv(): - # Most BLAS implementations of GEMV ignore y=nan when beta=0 - # PyTensor considers that the correct behavior, - # and even exploits it to avoid copying or initializing outputs. - # By deciding to exploit this, however, it becomes our responsibility - # to ensure the behavior even in the rare cases BLAS deviates, - # or users will get errors, even for graphs that had no nan to begin with. - y.fill(0) - - # Here I suppose that A is in c order. If we don't make it - # explicitly as fortran order, scipy 0.7.2 seam to create - # a copy in fortran order instead of just reshaping it - # and using the trans flag. - # If A is already in fortran order, make it in c order and using the - # trans flag don't seam to cause slowdown. - # out_storage[0][0] = gemv(alpha, A, x, beta, y, - # overwrite_y=self.inplace) - out_storage[0][0] = gemv( - alpha, A.T, x, beta, y, overwrite_y=self.inplace, trans=True - ) - else: - out = np.dot(A, x) - if alpha != 1: - out *= alpha - if beta != 0: - if beta != 1: - out += beta * y - else: - out += y - out_storage[0][0] = np.asarray(out, dtype=y.dtype) - - def infer_shape(self, fgraph, node, input_shapes): - return [input_shapes[0]] - - -gemv_no_inplace = Gemv(inplace=False) -gemv_inplace = Gemv(inplace=True) -# For the user interface. Opt will make them inplace later -gemv = gemv_no_inplace - - -class Ger(Op): - """ - BLAS defines general rank-1 update GER as A <- A + alpha x y' - - for matrix A, scalar alpha, vectors x and y. - - This interface to GER allows non-destructive operation on A via the - `destructive` argument to the constructor. - - """ - - __props__ = ("destructive",) - - def __init__(self, destructive): - self.destructive = destructive - if destructive: - self.destroy_map = {0: [0]} - - def __str__(self): - if self.destructive: - return f"{self.__class__.__name__}{{destructive}}" - else: - return f"{self.__class__.__name__}{{non-destructive}}" - - def make_node(self, A, alpha, x, y): - A = as_tensor_variable(A) - y = as_tensor_variable(y) - x = as_tensor_variable(x) - alpha = as_tensor_variable(alpha) - if not (A.dtype == x.dtype == y.dtype == alpha.dtype): - raise TypeError( - "ger requires matching dtypes", (A.dtype, alpha.dtype, x.dtype, y.dtype) - ) - if alpha.ndim != 0: - raise TypeError("ger requires scalar alpha", alpha.type) - if A.ndim != 2: - raise TypeError("ger requires matrix for A", A.type) - if x.ndim != 1: - raise TypeError("ger requires vector for x", x.type) - if y.ndim != 1: - raise TypeError("ger requires vector for y", y.type) - - if x.dtype not in ("float32", "float64", "complex64", "complex128"): - raise TypeError("only float and complex types supported", x.dtype) - - inputs = [A, alpha, x, y] - if any(not isinstance(i.type, DenseTensorType) for i in inputs): - raise NotImplementedError("Only dense tensor types are supported") - - return Apply(self, inputs, [A.type()]) - - def perform(self, node, inputs, output_storage): - A, alpha, x, y = inputs - if A.size: - # GER doesn't handle zero-sized inputs - ger_func = get_blas_funcs("ger", dtype=A.dtype) - if A.flags["C_CONTIGUOUS"]: - # Work on transposed system to avoid copying - A = ger_func(alpha, y, x, a=A.T, overwrite_a=self.destructive).T - else: - A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive) - output_storage[0][0] = A - - def infer_shape(self, fgraph, node, input_shapes): - return [input_shapes[0]] - - -ger = Ger(destructive=False) -ger_destructive = Ger(destructive=True) - - -def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False): - """Extract a list of compilation flags from config.blas__ldflags. - - Depending on the options, different type of flags will be kept. - It returns a list of libraries against which an Op's object file - should be linked to benefit from a BLAS implementation. - - Parameters - ---------- - libs : bool, optional - Extract flags starting with "-l" (the default is True). - libs_dir : bool, optional - Extract flags starting with "-L" (the default is False). - include_dir : bool, optional - Extract flags starting with "-I" (the default is False). - flags: bool, optional - Extract all the other flags (the default is False). - - Returns - ------- - list of strings - Extracted flags. - - """ - ldflags_str = config.blas__ldflags - return _ldflags( - ldflags_str=ldflags_str, - libs=libs, - flags=flags, - libs_dir=libs_dir, - include_dir=include_dir, - ) - - -@functools.cache -def _ldflags( - ldflags_str: str, libs: bool, flags: bool, libs_dir: bool, include_dir: bool -) -> list[str]: - """Extract list of compilation flags from a string. - - Depending on the options, different type of flags will be kept. - - Parameters - ---------- - ldflags_str : string - The string to process. Typically, this will be the content of - `config.blas__ldflags`. - libs : bool - Extract flags starting with "-l". - flags: bool - Extract all the other flags. - libs_dir: bool - Extract flags starting with "-L". - include_dir: bool - Extract flags starting with "-I". - - Returns - ------- - list of strings - Extracted flags. - - """ - rval = [] - if libs_dir: - found_dyn = False - dirs = [x[2:] for x in shlex.split(ldflags_str) if x.startswith("-L")] - l = _ldflags( - ldflags_str=ldflags_str, - libs=True, - flags=False, - libs_dir=False, - include_dir=False, - ) - for d in dirs: - for f in Path(d.strip('"')).iterdir(): - if f.suffix in {".so", ".dylib", ".dll"}: - if any(f.stem.find(ll) >= 0 for ll in l): - found_dyn = True - # Special treatment of clang framework. Specifically for MacOS Accelerate - if "-framework" in l and "Accelerate" in l: - found_dyn = True - if not found_dyn and dirs: - _logger.warning( - "We did not find a dynamic library in the " - "library_dir of the library we use for blas. If you use " - "ATLAS, make sure to compile it with dynamics library." - ) - - split_flags = shlex.split(ldflags_str) - skip = False - for pos, t in enumerate(split_flags): - if skip: - skip = False - continue - # Remove extra quote. - if (t.startswith("'") and t.endswith("'")) or ( - t.startswith('"') and t.endswith('"') - ): - t = t[1:-1] - - try: - t0, t1 = t[0], t[1] - assert t0 == "-" or Path(t).exists() - except Exception: - raise ValueError(f'invalid token "{t}" in ldflags_str: "{ldflags_str}"') - if t == "-framework": - skip = True - # Special treatment of clang framework. Specifically for MacOS Accelerate - # The clang framework implicitly adds: header dirs, libraries, and library dirs. - # If we choose to always return these flags, we run into a huge deal amount of - # incompatibilities. For this reason, we only return the framework if libs are - # requested. - if ( - libs - and len(split_flags) >= pos - and split_flags[pos + 1] == "Accelerate" - ): - # We only add the Accelerate framework, but in the future we could extend it to - # other frameworks - rval.append(t) - rval.append(split_flags[pos + 1]) - elif libs_dir and t1 == "L": - rval.append(t[2:]) - elif include_dir and t1 == "I": - raise ValueError( - "Include dirs are not used for blas. We disable" - " this as this can hide other headers and this" - " is not wanted.", - t, - ) - elif libs and t1 == "l": # example -lmkl - rval.append(t[2:]) - elif flags and t1 not in ("L", "I", "l"): # example -openmp - rval.append(t) - elif flags and t1 == "L": - # to find it when we load the compiled op if the env of the - # used is not well configured. - rval.append("-Wl,-rpath," + t[2:]) - return rval - - -class GemmRelated(COp): - """Base class for Gemm and Dot22. - - This class provides a kind of templated gemm Op. - - """ - - __props__: tuple[str, ...] = () - - def c_support_code(self, **kwargs): - # return cblas_header_text() - mod_str = """ - #ifndef MOD - #define MOD % - #endif - void compute_strides(npy_intp *shape, int N_shape, int type_size, npy_intp *res) { - int s; - res[N_shape - 1] = type_size; - for (int i = N_shape - 1; i > 0; i--) { - s = shape[i]; - res[i - 1] = res[i] * (s > 0 ? s : 1); - } - } - """ - return blas_header_text() + mod_str - - def c_headers(self, **kwargs): - return [] - - def c_libraries(self, **kwargs): - return ldflags() - - # code_cache_version is built by subclasses from - # build_gemm_version - - def c_compile_args(self, **kwargs): - return ldflags(libs=False, flags=True) - - def c_lib_dirs(self, **kwargs): - return ldflags(libs=False, libs_dir=True) - - def c_header_dirs(self, **kwargs): - return ldflags(libs=False, include_dir=True) - - declare_NS = """ - int unit = 0; - - int type_num = PyArray_DESCR(%(_x)s)->type_num; - int type_size = PyArray_ITEMSIZE(%(_x)s); // in bytes - - npy_intp* Nx = PyArray_DIMS(%(_x)s); - npy_intp* Ny = PyArray_DIMS(%(_y)s); - npy_intp* Nz = 0; //PyArray_DIMS(%(_zout)s); - - npy_intp* Sx = PyArray_STRIDES(%(_x)s); - npy_intp* Sy = PyArray_STRIDES(%(_y)s); - npy_intp* Sz = 0; //PyArray_STRIDES(%(_zout)s); - - //strides for x, y, z in dimensions 0, 1 - int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1; - """ - - # implement if you don't have an inplace props - # setup_z_Nz_Sz = None - # otherwise implement - # setup_z_Nz_Sz_inplace = None - # setup_z_Nz_Sz_outplace = None - - check_xyz_rank2 = """ - if (PyArray_NDIM(%(_x)s) != 2) { - PyErr_Format(PyExc_NotImplementedError, - "rank(x) != 2. rank(x) is %%d.", - PyArray_NDIM(%(_x)s)); - %(fail)s; - } - if (PyArray_NDIM(%(_y)s) != 2) { - PyErr_Format(PyExc_NotImplementedError, - "rank(y) != 2. rank(y) is %%d.", PyArray_NDIM(%(_y)s)); - %(fail)s; - } - if (%(_zout)s && PyArray_NDIM(%(_zout)s) != 2) { - PyErr_Format(PyExc_NotImplementedError, - "rank(z) != 2. rank(z) is %%d.", PyArray_NDIM(%(_zout)s)); - %(fail)s; - } - """ - check_xyz_double_or_float = """ - if ((PyArray_DESCR(%(_x)s)->type_num != NPY_DOUBLE) - && (PyArray_DESCR(%(_x)s)->type_num != NPY_FLOAT)) - {PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;} - - if ((PyArray_DESCR(%(_y)s)->type_num != NPY_DOUBLE) - && (PyArray_DESCR(%(_y)s)->type_num != NPY_FLOAT)) - {PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;} - - if ((PyArray_DESCR(%(_zout)s)->type_num != NPY_DOUBLE) - && (PyArray_DESCR(%(_zout)s)->type_num != NPY_FLOAT)) - {PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;} - - if ((PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_y)s)->type_num) - ||(PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_zout)s)->type_num)) - { PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); %(fail)s; } - """ - - # it is not necessary that a or b have the same type as x,y,z - check_ab_double_or_float = """ - if ((PyArray_DESCR(%(_a)s)->type_num != NPY_DOUBLE) - && (PyArray_DESCR(%(_a)s)->type_num != NPY_FLOAT)) - {PyErr_SetString(PyExc_NotImplementedError, "type(a) is not double or float"); %(fail)s;} - - if ((PyArray_DESCR(%(_b)s)->type_num != NPY_DOUBLE) - && (PyArray_DESCR(%(_b)s)->type_num != NPY_FLOAT)) - {PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;} - """ - - # broadcast_xy = None - - check_dims = """ - if (Nx[0] !=1 && Nz[0] != 1 && Nx[0] != Nz[0]) - { - PyErr_Format(PyExc_ValueError, - "Shape mismatch: x has %%ld rows but z has %%ld rows", - (long int)Nx[0], (long int)Nz[0]); - %(fail)s; - } - if (Nx[1] != Ny[0]) - { - PyErr_Format(PyExc_ValueError, - "Shape mismatch: x has %%ld cols (and %%ld rows) but y has %%ld rows (and %%ld cols)", - (long int)Nx[1], (long int)Nx[0], (long int)Ny[0], (long int)Ny[1]); - %(fail)s; - } - if (Ny[1] != 1 && Nz[1]!= 1 && Ny[1] != Nz[1]) - { - PyErr_Format(PyExc_ValueError, - "Shape mismatch: y has %%ld cols but z has %%ld cols", - (long int)Ny[1], (long int)Nz[1]); - %(fail)s; - } - - // We must not raise an error when Nx[1] == 0. This would disable cases - // that numpy.dot accept. - """ - - check_strides = """ - /* - If some matrices are not contiguous on either dimensions, - or have invalid strides, copy their content into a contiguous one - */ - if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] MOD type_size) || (Sx[1] MOD type_size) - || ((Sx[0] != type_size) && (Sx[1] != type_size))) - { - PyArrayObject * _x_copy = (PyArrayObject *) PyArray_Copy(%(_x)s); - if (!_x_copy) - %(fail)s - Py_XDECREF(%(_x)s); - %(_x)s = _x_copy; - Sx = PyArray_STRIDES(%(_x)s); - if ((Sx[0] < 1) || (Sx[1] < 1)) { - compute_strides(Nx, 2, type_size, Sx); - } - } - - if ((Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size) - || ((Sy[0] != type_size) && (Sy[1] != type_size))) - { - PyArrayObject * _y_copy = (PyArrayObject *) PyArray_Copy(%(_y)s); - if (!_y_copy) - %(fail)s - Py_XDECREF(%(_y)s); - %(_y)s = _y_copy; - Sy = PyArray_STRIDES(%(_y)s); - if ((Sy[0] < 1) || (Sy[1] < 1)) { - compute_strides(Ny, 2, type_size, Sy); - } - } - - if ((Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size) - || ((Sz[0] != type_size) && (Sz[1] != type_size))) - { - PyArrayObject * _z_copy = (PyArrayObject *) PyArray_Copy(%(_zout)s); - if (!_z_copy) - %(fail)s - Py_XDECREF(%(_zout)s); - %(_zout)s = _z_copy; - Sz = PyArray_STRIDES(%(_zout)s); - if ((Sz[0] < 1) || (Sz[1] < 1)) { - compute_strides(Nz, 2, type_size, Sz); - } - } - """ - - encode_strides_in_unit = """ - /* - encode the stride structure of _x,_y,_zout into a single integer - */ - unit |= ((Sx[1] == type_size || Nx[1]==1) ? 0x0 : (Sx[0] == type_size || Nx[0]==1) ? 0x1 : 0x2) << 8; - unit |= ((Sy[1] == type_size || Ny[1]==1) ? 0x0 : (Sy[0] == type_size || Ny[0]==1) ? 0x1 : 0x2) << 4; - unit |= ((Sz[1] == type_size || Nz[1]==1) ? 0x0 : (Sz[0] == type_size || Nz[0]==1) ? 0x1 : 0x2) << 0; - """ - - compute_strides = """ - /* create appropriate strides for malformed matrices that are row or column - * vectors, or empty matrices. - * In that case, the value of the stride does not really matter, but - * some versions of BLAS insist that: - * - they are not smaller than the number of elements in the array, - * - they are not 0. - */ - sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : (Nx[1] + 1); - sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : (Nx[0] + 1); - sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : (Ny[1] + 1); - sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : (Ny[0] + 1); - sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : (Nz[1] + 1); - sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : (Nz[0] + 1); - """ - - begin_switch_typenum = """ - switch (type_num) - { - """ - - case_float = """ - case NPY_FLOAT: - { - """ - - # case_float_ab_constants = None - - case_float_gemm = """ - float* x = (float*)PyArray_DATA(%(_x)s); - float* y = (float*)PyArray_DATA(%(_y)s); - float* z = (float*)PyArray_DATA(%(_zout)s); - char N = 'N'; - char T = 'T'; - int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1]; - switch(unit) - { - case 0x000: sgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break; - case 0x100: sgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break; - case 0x010: sgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break; - case 0x110: sgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break; - case 0x001: sgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break; - case 0x101: sgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break; - case 0x011: sgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break; - case 0x111: sgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break; - default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s; - }; - """ - - case_double = """ - } - break; - case NPY_DOUBLE: - { - """ - - # case_double_ab_constants = None - - case_double_gemm = """ - double* x = (double*)PyArray_DATA(%(_x)s); - double* y = (double*)PyArray_DATA(%(_y)s); - double* z = (double*)PyArray_DATA(%(_zout)s); - char N = 'N'; - char T = 'T'; - int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1]; - switch(unit) - { - case 0x000: dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, - &sy_0, x, &sx_0, &b, z, &sz_0); break; - case 0x100: dgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, - &sy_0, x, &sx_1, &b, z, &sz_0); break; - case 0x010: dgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, - &sy_1, x, &sx_0, &b, z, &sz_0); break; - case 0x110: dgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, - &sy_1, x, &sx_1, &b, z, &sz_0); break; - case 0x001: dgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, - &sx_0, y, &sy_0, &b, z, &sz_1); break; - case 0x101: dgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, - &sx_1, y, &sy_0, &b, z, &sz_1); break; - case 0x011: dgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, - &sx_0, y, &sy_1, &b, z, &sz_1); break; - case 0x111: dgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, - &sx_1, y, &sy_1, &b, z, &sz_1); break; - default: PyErr_SetString(PyExc_ValueError, - "some matrix has no unit stride"); - %(fail)s; - }; - """ - - end_switch_typenum = """ - } - break; - } - """ - - def build_gemm_call(self): - if hasattr(self, "inplace"): - setup_z_Nz_Sz = f"if(%(params)s->inplace){{{self.setup_z_Nz_Sz_inplace}}}else{{{self.setup_z_Nz_Sz_outplace}}}" - else: - setup_z_Nz_Sz = self.setup_z_Nz_Sz - - return "".join( - ( - self.declare_NS, - self.check_xyz_rank2, - setup_z_Nz_Sz, - self.check_xyz_double_or_float, - self.check_ab_double_or_float, - self.broadcast_xy, - self.check_dims, - self.check_strides, - self.encode_strides_in_unit, - self.compute_strides, - self.begin_switch_typenum, - self.case_float, - self.case_float_ab_constants, - self.case_float_gemm, - self.case_double, - self.case_double_ab_constants, - self.case_double_gemm, - self.end_switch_typenum, - ) - ) - - def build_gemm_version(self): - return (14, blas_header_version()) - - -class Gemm(GemmRelated): - """In-place version of matrix-matrix multiplication (with accumulation). - - When a and b are scalars and x, y, and z are matrices, then - - gemm(z,a,x,y,b) - - is similar to - - b*z + a*dot(x,y) - - The difference between the two is that the top form is destructive - on z, whereas the bottom form is not. Gemm works in-place on the - storage associated with z, and the L{Variable} returned by Gemm - has a storage that will be aliased to the storage of the z - argument. Because of this in-place computation, an L{Apply} of - this op will destroy the L{Variable} z on which it operates. (See - L{DestructiveOps} for an explanation of what destroying means in - the context of pytensor graphs. See L{BlasLapackSupport} for more - optimized linear algebra operations.) - - """ - - E_rank = "gemm only works for rank 2" - E_scalar = "gemm requires scalar argument" - E_z_uniq = "argument z aliased to x or y" # TODO: justify / delete this - E_mixed = "gemm requires matching dtypes" - E_float = "gemm requires floating-point dtypes" - - __props__ = ("inplace",) - params_type = ParamsType( - inplace=bool_t, - ) - check_input = False - - def __init__(self, inplace): - self.inplace = inplace - if self.inplace: - self.destroy_map = {0: [0]} - - def __str__(self): - if self.inplace: - inplace_str = "inplace" - else: - inplace_str = "no_inplace" - return f"{self.__class__.__name__}{{{inplace_str}}}" - - def __setstate__(self, dct): - self.__dict__.update(dct) - - # Correctly reload older pickles where destroy_map were not - # saved - if "destroy_map" not in self.__dict__ and self.inplace: - self.destroy_map = {0: [0]} - - def __getstate__(self): - rval = self.__dict__.copy() - # Do not serialize the setup code, it will be restored in __setstate__ - # depending on the value of 'inplace' - rval.pop("setup_z_Nz_Sz", None) - return rval - - def make_node(self, *inputs): - inputs = list(map(as_tensor_variable, inputs)) - - if any(not isinstance(i.type, DenseTensorType) for i in inputs): - raise NotImplementedError("Only dense tensor types are supported") - - if len(inputs) != 5: - raise TypeError( - f"Wrong number of inputs for {self} (expected 5, got {len(inputs)})" - ) - z, a, x, y, b = inputs - - zr, xr, yr = (set(view_roots(i)) for i in (z, x, y)) - - # We want the gemm to be inplace. When this op is inplace, it - # declare to be inplace only on z. So to make it safe, we - # raise an error if z can be a view on x or y. - - # I don't know if PyTensor currently can support that case. As - # this case don't happen in our code, I won't spent time - # investigating this. So the assert is for safety. I also - # think there is another mechanism that would prevent this, - # but I don't what to modify old code and have chance to break - # something. - if self.inplace: - if zr.intersection(xr): - raise InconsistencyError(Gemm.E_z_uniq, (z, x)) - if zr.intersection(yr): - raise InconsistencyError(Gemm.E_z_uniq, (z, y)) - - if z.ndim != 2: - raise TypeError(Gemm.E_rank, z) - if a.ndim != 0: - raise TypeError(Gemm.E_scalar, a) - if x.ndim != 2: - raise TypeError(Gemm.E_rank, x) - if y.ndim != 2: - raise TypeError(Gemm.E_rank, y) - if b.ndim != 0: - raise TypeError(Gemm.E_scalar, b) - - if not (z.dtype == a.dtype == x.dtype == y.dtype == b.dtype): - raise TypeError(Gemm.E_mixed, (z.dtype, a.dtype, x.dtype, y.dtype, b.dtype)) - - if not z.dtype.startswith("float") and not z.dtype.startswith("complex"): - raise TypeError(Gemm.E_float, (z.dtype)) - - output = z.type() - return Apply(self, inputs, [output]) - - def perform(self, node, inp, out): - z, a, x, y, b = inp - (zout,) = out - assert a.shape == () - assert b.shape == () - if not self.inplace: - z = z.copy() # the original z will not be changed - if z.shape == (): - z.itemset(z * a + b * np.dot(x, y)) - zout[0] = z - else: - # Broadcast Z if needed - if (x.shape[0] > z.shape[0]) or (y.shape[1] > z.shape[1]): - z = np.broadcast_to( - z, (max(x.shape[0], z.shape[0]), max(y.shape[1], z.shape[1])) - ).copy() - if b == 0.0: - if a == 1.0: - z[:] = np.dot(x, y) - elif a == -1.0: - z[:] = -np.dot(x, y) - else: - z[:] = a * np.dot(x, y) - elif b == 1.0: - if a == 1.0: - z += np.dot(x, y) - elif a == -1.0: - z -= np.dot(x, y) - else: - z += a * np.dot(x, y) - else: - z *= b - z += a * np.dot(x, y) - zout[0] = z - - def infer_shape(self, fgraph, node, input_shapes): - z_shape, _, x_shape, y_shape, _ = input_shapes - return [ - ( - pytensor.scalar.maximum(z_shape[0], x_shape[0]), - pytensor.scalar.maximum(z_shape[1], y_shape[1]), - ) - ] - - setup_z_Nz_Sz_inplace = """ - // Needs broadcasting - if (PyArray_DIMS(%(_z)s)[0] < Nx[0] || PyArray_DIMS(%(_z)s)[1] < Ny[1]){ - - npy_intp dims[2]; - dims[0] = (PyArray_DIMS(%(_z)s)[0] >= Nx[0]) ? PyArray_DIMS(%(_z)s)[0] : Nx[0]; - dims[1] = (PyArray_DIMS(%(_z)s)[1] >= Ny[1]) ? PyArray_DIMS(%(_z)s)[1] : Ny[1]; - - // Check if we need to allocate new array - if((NULL == %(_zout)s) - || (PyArray_DIMS(%(_zout)s)[0] != dims[0]) - || (PyArray_DIMS(%(_zout)s)[1] != dims[1])) - { - // fprintf(stderr, "Gemm Allocating z output array with shape (%%i %%i)\\n", dims[0], dims[1]); - Py_XDECREF(%(_zout)s); - %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_z)s)); - } - - // fprintf(stderr, "Gemm Broadcasting Z into shape (%%i %%i)\\n", dims[0], dims[1]); - if(PyArray_CopyInto(%(_zout)s, %(_z)s) == -1) - { - %(fail)s; - } - - } else { - if (%(_zout)s != %(_z)s) - { - Py_XDECREF(%(_zout)s); - %(_zout)s = %(_z)s; - Py_INCREF(%(_zout)s); - } - } - - Nz = PyArray_DIMS(%(_zout)s); - Sz = PyArray_STRIDES(%(_zout)s); - """ - - setup_z_Nz_Sz_outplace = """ - npy_intp dims[2]; - dims[0] = (PyArray_DIMS(%(_z)s)[0] >= Nx[0]) ? PyArray_DIMS(%(_z)s)[0] : Nx[0]; - dims[1] = (PyArray_DIMS(%(_z)s)[1] >= Ny[1]) ? PyArray_DIMS(%(_z)s)[1] : Ny[1]; - - // Check if we need to allocate new array - if ((NULL == %(_zout)s) - || (PyArray_DIMS(%(_zout)s)[0] != dims[0]) - || (PyArray_DIMS(%(_zout)s)[1] != dims[1])) - { - Py_XDECREF(%(_zout)s); - %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_z)s)); - // fprintf(stderr, "Gemm Allocating z output array with shape (%%i %%i)\\n", dims[0], dims[1]); - if(!%(_zout)s) { - PyErr_SetString(PyExc_MemoryError, - "failed to alloc gemm_no_inplace output"); - %(fail)s - } - } - - // fprintf(stderr, "Gemm Broadcasting Z into shape (%%i %%i)\\n", dims[0], dims[1]); - if(PyArray_CopyInto(%(_zout)s, %(_z)s) == -1) - { - %(fail)s - } - - Nz = PyArray_DIMS(%(_zout)s); - Sz = PyArray_STRIDES(%(_zout)s); - """ - - broadcast_xy = """ - // Broadcast X if needed - if (Nz[0] > Nx[0]) - { - npy_intp dims[2]; - dims[0] = Nz[0]; - dims[1] = Nx[1]; - // fprintf(stderr, "Gemm Broadcasting X into shape (%%i %%i)\\n", dims[0], dims[1]); - PyArrayObject *x_new = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_x)s)); - if(!x_new) { - PyErr_SetString(PyExc_MemoryError, - "failed to alloc gemm_inplace input"); - %(fail)s - } - - if(PyArray_CopyInto(x_new, %(_x)s) == -1) - { - %(fail)s - } - - Py_DECREF(%(_x)s); - %(_x)s = x_new; - - Nx = PyArray_DIMS(%(_x)s); - Sx = PyArray_STRIDES(%(_x)s); - } - - // Broadcast Y if needed - if (Nz[1] > Ny[1]) - { - npy_intp dims[2]; - dims[0] = Ny[0]; - dims[1] = Nz[1]; - // fprintf(stderr, "Gemm Broadcasting Y into shape (%%i %%i)\\n", dims[0], dims[1]); - PyArrayObject *y_new = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_x)s)); - if(!y_new) { - PyErr_SetString(PyExc_MemoryError, - "failed to alloc gemm_inplace input"); - %(fail)s - } - - if(PyArray_CopyInto(y_new, %(_y)s) == -1) - { - %(fail)s - } - - Py_DECREF(%(_y)s); - %(_y)s = y_new; - - Ny = PyArray_DIMS(%(_y)s); - Sy = PyArray_STRIDES(%(_y)s); - } - - """ - - case_float_ab_constants = """ - #define REAL float - float a = (PyArray_DESCR(%(_a)s)->type_num == NPY_FLOAT) - ? (REAL)(((float*)PyArray_DATA(%(_a)s))[0]) - : (REAL)(((double*)PyArray_DATA(%(_a)s))[0]); - float b = (PyArray_DESCR(%(_b)s)->type_num == NPY_FLOAT) ? - (REAL)(((float*)PyArray_DATA(%(_b)s))[0]) - : (REAL)(((double*)PyArray_DATA(%(_b)s))[0]); - #undef REAL - """ - case_double_ab_constants = """ - #define REAL double - double a = (PyArray_DESCR(%(_a)s)->type_num == NPY_FLOAT) - ? (REAL)(((float*)PyArray_DATA(%(_a)s))[0]) - : (REAL)(((double*)PyArray_DATA(%(_a)s))[0]); - double b = (PyArray_DESCR(%(_b)s)->type_num == NPY_FLOAT) ? - (REAL)(((float*)PyArray_DATA(%(_b)s))[0]) - : (REAL)(((double*)PyArray_DATA(%(_b)s))[0]); - #undef REAL - """ - - def c_code(self, node, name, inp, out, sub): - _z, _a, _x, _y, _b = inp - (_zout,) = out - if node.inputs[0].type.dtype.startswith("complex"): - raise MethodNotDefined(f"{self.__class__.__name__}.c_code") - full_code = self.build_gemm_call() % dict(locals(), **sub) - return full_code - - def c_code_cache_version(self): - gv = self.build_gemm_version() - if gv: - return (8, *gv) - else: - return gv - - -gemm_inplace = Gemm(inplace=True) -gemm_no_inplace = Gemm(inplace=False) -# For the user interface. PyTensor optimization will make them inplace -gemm = gemm_no_inplace -pprint.assign(gemm_inplace, FunctionPrinter(["gemm_inplace"])) -pprint.assign(gemm_no_inplace, FunctionPrinter(["gemm_no_inplace"])) - - -class Dot22(GemmRelated): - """Compute a matrix-matrix product. - - This is a specialization of the more general Dot(). - - """ - - check_input = False - - def make_node(self, x, y): - x = as_tensor_variable(x) - y = as_tensor_variable(y) - - if any(not isinstance(i.type, DenseTensorType) for i in (x, y)): - raise NotImplementedError("Only dense tensor types are supported") - - dtypes = ("float16", "float32", "float64", "complex64", "complex128") - if x.type.ndim != 2 or x.type.dtype not in dtypes: - raise TypeError(x) - if y.type.ndim != 2 or y.type.dtype not in dtypes: - raise TypeError(y) - if y.type.dtype != x.type.dtype: - raise TypeError("dtype mismatch to Dot22") - outputs = [tensor(dtype=x.type.dtype, shape=(x.type.shape[0], y.type.shape[1]))] - return Apply(self, [x, y], outputs) - - def perform(self, node, inputs, output_storage): - output_storage[0][0] = np.dot(*inputs) - - def infer_shape(self, fgraph, node, input_shapes): - return [[input_shapes[0][0], input_shapes[1][1]]] - - setup_z_Nz_Sz = """ - if ((NULL == %(_zout)s) - || (PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_x)s)[0]) - || (PyArray_DIMS(%(_zout)s)[1] != PyArray_DIMS(%(_y)s)[1])) - { - if (NULL != %(_zout)s) Py_XDECREF(%(_zout)s); - npy_intp dims[2]; - dims[0] = PyArray_DIMS(%(_x)s)[0]; - dims[1] = PyArray_DIMS(%(_y)s)[1]; - %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, - PyArray_TYPE(%(_x)s)); - //fprintf(stderr, "Dot Allocating %%i %%i\\n", dims[0], dims[1]); - if(!%(_zout)s) { - PyErr_SetString(PyExc_MemoryError, - "failed to alloc dot22 output"); - %(fail)s - } - } - Nz = PyArray_DIMS(%(_zout)s); - Sz = PyArray_STRIDES(%(_zout)s); - - """ - broadcast_xy = "" - check_ab_double_or_float = "" - case_float_ab_constants = """ - float a = 1.0; - float b = 0.0; - """ - case_double_ab_constants = """ - double a = 1.0; - double b = 0.0; - """ - - def c_code(self, node, name, inp, out, sub): # DEBUG - _x, _y = inp - (_zout,) = out - if node.inputs[0].type.dtype.startswith("complex"): - raise MethodNotDefined(f"{self.__class__.__name__}.c_code") - if len(self.c_libraries()) <= 0: - raise NotImplementedError() - full_code = self.build_gemm_call() % dict(locals(), **sub) - return full_code - - def c_code_cache_version(self): - gv = self.build_gemm_version() - if gv: - return (2, *gv) - else: - return gv - - -_dot22 = Dot22() - - -class Dot22Scalar(GemmRelated): - """Compute a matrix-matrix product. - - This is a specialization of the more general Dot() - Used to call optimized gemm implementation. - Also used to generate a gemm later. - compute scalar*dot(x,y). - - """ - - check_input = False - - def make_node(self, x, y, a): - if any(not isinstance(i.type, DenseTensorType) for i in (x, y, a)): - raise NotImplementedError("Only dense tensor types are supported") - - if a.ndim != 0: - raise TypeError(Gemm.E_scalar, a) - if x.ndim != 2: - raise TypeError(Gemm.E_rank, x) - if y.ndim != 2: - raise TypeError(Gemm.E_rank, y) - - if not (a.dtype == x.dtype == y.dtype): - raise TypeError( - "Dot22Scalar requires matching dtypes", (a.dtype, x.dtype, y.dtype) - ) - - if not a.dtype.startswith("float") and not a.dtype.startswith("complex"): - raise TypeError("Dot22Scalar requires float or complex args", a.dtype) - - sz = (x.type.shape[0], y.type.shape[1]) - outputs = [tensor(dtype=x.type.dtype, shape=sz)] - return Apply(self, [x, y, a], outputs) - - def perform(self, node, inp, out): - x, y, scalar = inp - (z,) = out - try: - z[0] = np.asarray(scalar * np.dot(x, y)) - except ValueError as e: - # The error raised by numpy has no shape information, we - # mean to add that - e.args = (*e.args, x.shape, y.shape) - raise - - def infer_shape(self, fgraph, node, input_shapes): - return [[input_shapes[0][0], input_shapes[1][1]]] - - setup_z_Nz_Sz = Dot22.setup_z_Nz_Sz - broadcast_xy = "" - - check_ab_double_or_float = """ - if ((PyArray_DESCR(%(_a)s)->type_num != NPY_DOUBLE) - && (PyArray_DESCR(%(_a)s)->type_num != NPY_FLOAT)) - {PyErr_SetString(PyExc_NotImplementedError, - "type(a) is not double or float"); %(fail)s;} - - """ - case_float_ab_constants = """ - #define REAL float - float a = (PyArray_DESCR(%(_a)s)->type_num == NPY_FLOAT) - ? (REAL)(((float*)PyArray_DATA(%(_a)s))[0]) - : (REAL)(((double*)PyArray_DATA(%(_a)s))[0]); - #undef REAL - float b = 0.0; - """ - - case_double_ab_constants = """ - #define REAL double - double a = (PyArray_DESCR(%(_a)s)->type_num == NPY_FLOAT) - ? (REAL)(((float*)PyArray_DATA(%(_a)s))[0]) - : (REAL)(((double*)PyArray_DATA(%(_a)s))[0]); - #undef REAL - double b = 0.0; - """ - - def c_code(self, node, name, inp, out, sub): - _x, _y, _a = inp - (_zout,) = out - if node.inputs[0].type.dtype.startswith("complex"): - raise MethodNotDefined(f"{self.__class__.__name__}.c_code") - if len(self.c_libraries()) <= 0: - raise NotImplementedError() - full_code = self.build_gemm_call() % dict(locals(), **sub) - return full_code - - def c_code_cache_version(self): - gv = self.build_gemm_version() - if gv: - return (2, *gv) - else: - return gv - - -_dot22scalar = Dot22Scalar() - - -class BatchedDot(COp): - """ - Computes a batch matrix-matrix dot with tensor3 variables - - batched_dot(a, b)[i] = dot(a[i], b[i]) - """ - - __props__ = () - gufunc_signature = "(b,m,k),(b,k,n)->(b,m,n)" - - def make_node(self, x, y): - x = as_tensor_variable(x) - y = as_tensor_variable(y) - - if not ( - isinstance(x.type, DenseTensorType) and isinstance(y.type, DenseTensorType) - ): - raise NotImplementedError("Only dense tensor types are supported") - - if not (x.type.ndim == 3 and y.type.ndim == 3): - raise TypeError( - f"Inputs must have 3 ndim, but got {x.type.ndim} and {y.type.ndim}. " - "Consider calling batched_dot instead." - ) - - def extract_static_dim(dim_x, dim_y): - dims = {dim_x, dim_y} - {None} - if len(dims) > 1: - # BatchedDot doesn't allow broadcasting - raise ValueError( - f"Static dimensions of BatchedDot don't match, got {x.type.shape} and {y.type.shape}" - ) - elif not dims: - return None - else: - return dims.pop() - - x_batch_dim, x_row_dim, x_sum_dim = x.type.shape - y_batch_dim, y_sum_dim, y_col_dim = y.type.shape - batch_dim = extract_static_dim(x_batch_dim, y_batch_dim) - # Raise if static sum dimensions do not match - _ = extract_static_dim(x_sum_dim, y_sum_dim) - out_shape = (batch_dim, x_row_dim, y_col_dim) - - # Change dtype if needed - dtype = pytensor.scalar.upcast(x.type.dtype, y.type.dtype) - x, y = cast(x, dtype), cast(y, dtype) - out = tensor(dtype=dtype, shape=out_shape) - return Apply(self, [x, y], [out]) - - def perform(self, node, inp, out): - x, y = inp - (z,) = out - - if x.shape[0] != y.shape[0]: - raise TypeError( - f"Inputs [{', '.join(map(str, inp))}] must have the" - f" same size in axis 0, but have sizes [{', '.join(str(i.shape[0]) for i in inp)}]." - ) - - z[0] = np.matmul(x, y) - - def c_support_code(self, **kwargs): - batch_gemm_defn = """ - template - bool batch_gemm(void (*gemm)(char*, char*, const int*, const int*, const int*, const dtype*, const dtype*, const int*, const dtype*, const int*, const dtype*, dtype*, const int*), - int type_size, PyArrayObject* xs, PyArrayObject* ys, - PyArrayObject* zs) { - npy_intp *Nx = PyArray_DIMS(xs), *Sx = PyArray_STRIDES(xs); - npy_intp *Ny = PyArray_DIMS(ys), *Sy = PyArray_STRIDES(ys); - npy_intp *Nz = PyArray_DIMS(zs), *Sz = PyArray_STRIDES(zs); - - if (Nx[0] != Ny[0]) { - PyErr_Format(PyExc_ValueError, - "Shape mismatch: batch sizes unequal." - " x.shape is (%d, %d, %d)," - " y.shape is (%d, %d, %d).", - Nx[0], Nx[1], Nx[2], - Ny[0], Ny[1], Ny[2]); - return 1; - } - - if (Nx[2] != Ny[1]) { - PyErr_Format(PyExc_ValueError, - "Shape mismatch: summation axis sizes unequal." - " x.shape is (%d, %d, %d)," - " y.shape is (%d, %d, %d).", - Nx[0], Nx[1], Nx[2], - Ny[0], Ny[1], Ny[2]); - return 1; - } - - /* encode the stride structure of _x,_y,_z into a single integer. */ - int unit = 0; - unit |= ((Sx[2] == type_size || Nx[2] == 1) ? 0x0 : (Sx[1] == type_size || Nx[1]==1) ? 0x1 : 0x2) << 8; - unit |= ((Sy[2] == type_size || Ny[2] == 1) ? 0x0 : (Sy[1] == type_size || Ny[1]==1) ? 0x1 : 0x2) << 4; - unit |= ((Sz[2] == type_size || Nz[2] == 1) ? 0x0 : (Sz[1] == type_size || Nz[1]==1) ? 0x1 : 0x2) << 0; - - /* create appropriate strides for malformed matrices that are row or column - * vectors, or empty matrices. - * In that case, the value of the stride does not really matter, but - * some versions of BLAS insist that: - * - they are not smaller than the number of elements in the array, - * - they are not 0. - */ - int sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : (Nx[2] + 1); - int sx_2 = (Nx[2] > 1) ? Sx[2]/type_size : (Nx[1] + 1); - int sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : (Ny[2] + 1); - int sy_2 = (Ny[2] > 1) ? Sy[2]/type_size : (Ny[1] + 1); - int sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : (Nz[2] + 1); - int sz_2 = (Nz[2] > 1) ? Sz[2]/type_size : (Nz[1] + 1); - - dtype* x = (dtype*)PyArray_DATA(xs); - dtype* y = (dtype*)PyArray_DATA(ys); - dtype* z = (dtype*)PyArray_DATA(zs); - - dtype a = 1.0; - dtype b = 0.0; - char N = 'N'; - char T = 'T'; - int Nz1 = Nz[1], Nz2 = Nz[2], Nx2 = Nx[2]; - - // loop over batch axis - for (int i = 0; i < Nz[0]; i++) { - switch(unit) - { - case 0x000: gemm(&N, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_1, &b, z, &sz_1); break; - case 0x100: gemm(&N, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_2, &b, z, &sz_1); break; - case 0x010: gemm(&T, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_1, &b, z, &sz_1); break; - case 0x110: gemm(&T, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_2, &b, z, &sz_1); break; - case 0x001: gemm(&T, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_1, &b, z, &sz_2); break; - case 0x101: gemm(&N, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_1, &b, z, &sz_2); break; - case 0x011: gemm(&T, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_2, &b, z, &sz_2); break; - case 0x111: gemm(&N, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_2, &b, z, &sz_2); break; - default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); return 1; - }; - x += Sx[0] / type_size; - y += Sy[0] / type_size; - z += Sz[0] / type_size; - } - - return 0; - } - """ - return blas_header_text() + batch_gemm_defn - - def c_libraries(self, **kwargs): - return ldflags() - - def c_compile_args(self, **kwargs): - return ldflags(libs=False, flags=True) - - def c_lib_dirs(self, **kwargs): - return ldflags(libs=False, libs_dir=True) - - def c_header_dirs(self, **kwargs): - return ldflags(libs=False, include_dir=True) - - def c_code(self, node, name, inp, out, sub): - # Can only compile if linked to blas libraries - if len(self.c_libraries()) <= 0: - raise NotImplementedError() - - _x, _y = inp - (_z,) = out - fail = sub["fail"] - - # generate contiguity condition - def contiguous(var, ndim): - strides = f"PyArray_STRIDES({var})" - if ndim == 1: - return f"{strides}[0] == type_size" - ands = " && ".join( - f"{strides}[{i}] > 0 && {strides}[{i}] % type_size == 0" - for i in range(1, ndim) - ) - ors = " || ".join(f"{strides}[{i}] == type_size" for i in range(1, ndim)) - return f"{ands} && ({ors})" - - x_ndim, y_ndim, z_ndim = ( - node.inputs[0].ndim, - node.inputs[1].ndim, - node.outputs[0].ndim, - ) - - # generate code to allocate output based on runtime input shapes - z_dims = [ - f"PyArray_DIMS({_x})[0]", - f"PyArray_DIMS({_x})[1]", - f"PyArray_DIMS({_y})[2]", - ] - - z_shape_correct = " && ".join( - f"PyArray_DIMS({_z})[{i}] == {dim}" for i, dim in enumerate(z_dims) - ) - z_shape = ", ".join(z_dims) - z_contiguous = contiguous(_z, z_ndim) - allocate = f""" - if (NULL == {_z} || !({z_shape_correct}) || !({z_contiguous})) - {{ - npy_intp dims[{z_ndim}] = {{{z_shape}}}; - Py_XDECREF({_z}); - {_z} = (PyArrayObject*)PyArray_SimpleNew( - {z_ndim}, dims, PyArray_TYPE({_x})); - if(!{_z}) {{ - PyErr_SetString(PyExc_MemoryError, - "failed to alloc BatchedDot output"); - {fail} - }} - }} - """ - - # code to reallocate inputs contiguously if necessary - contiguate = [] - for var, ndim in [(_x, x_ndim), (_y, y_ndim)]: - _contiguous = contiguous(var, ndim) - contiguate.append( - f""" - if (!({_contiguous})) {{ - PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy({var}); - if (!_copy) - {fail} - Py_XDECREF({var}); - {var} = _copy; - }} - """ - ) - contiguate = "\n".join(contiguate) - - return f""" - int type_num = PyArray_DESCR({_x})->type_num; - int type_size = PyArray_ITEMSIZE({_x}); // in bytes - - if (PyArray_NDIM({_x}) != 3) {{ - PyErr_Format(PyExc_NotImplementedError, - "rank(x) != 3. rank(x) is %d.", - PyArray_NDIM({_x})); - {fail}; - }} - if (PyArray_NDIM({_y}) != 3) {{ - PyErr_Format(PyExc_NotImplementedError, - "rank(y) != 3. rank(y) is %d.", - PyArray_NDIM({_y})); - {fail}; - }} - if ({_z} && PyArray_NDIM({_z}) != 3) {{ - PyErr_Format(PyExc_NotImplementedError, - "rank(z) != 3. rank(z) is %d.", - PyArray_NDIM({_z})); - {fail}; - }} - - // allocate output - {allocate} - // reallocate any noncontiguous arrays or arrays with invalid strides - {contiguate} - - if ((PyArray_DESCR({_x})->type_num != NPY_DOUBLE) - && (PyArray_DESCR({_x})->type_num != NPY_FLOAT)) - {{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); {fail};}} - - if ((PyArray_DESCR({_y})->type_num != NPY_DOUBLE) - && (PyArray_DESCR({_y})->type_num != NPY_FLOAT)) - {{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); {fail};}} - - if ((PyArray_DESCR({_z})->type_num != NPY_DOUBLE) - && (PyArray_DESCR({_z})->type_num != NPY_FLOAT)) - {{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); {fail};}} - - if ((PyArray_DESCR({_x})->type_num != PyArray_DESCR({_y})->type_num) - ||(PyArray_DESCR({_x})->type_num != PyArray_DESCR({_z})->type_num)) - {{ PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); {fail}; }} - - switch (type_num) - {{ - case NPY_FLOAT: - if (batch_gemm(sgemm_, type_size, {_x}, {_y}, {_z})) {{ - {fail}; - }} - break; - case NPY_DOUBLE: - if (batch_gemm(dgemm_, type_size, {_x}, {_y}, {_z})) {{ - {fail}; - }} - break; - }} - """ - - def c_code_cache_version(self): - from pytensor.tensor.blas_headers import blas_header_version - - return (6, blas_header_version()) - - def pullback(self, inp, outputs, grads): - x, y = inp - (gz,) = grads - - xgrad = _batched_dot(gz, y.dimshuffle(0, 2, 1)) - ygrad = _batched_dot(x.dimshuffle(0, 2, 1), gz) - - # If x or y contain broadcastable dimensions but only one of - # them know that a matching dimensions is broadcastable, the - # above code don't always return the right broadcast pattern. - # This cause problem down the road. See gh-1461. - if xgrad.broadcastable != x.broadcastable: - xgrad = specify_broadcastable( - xgrad, *(ax for (ax, b) in enumerate(x.type.broadcastable) if b) - ) - if ygrad.broadcastable != y.broadcastable: - ygrad = specify_broadcastable( - ygrad, *(ax for (ax, b) in enumerate(y.type.broadcastable) if b) - ) - - return xgrad, ygrad - - def pushforward(self, inputs, outputs, eval_points): - assert len(inputs) == 2 - assert len(eval_points) == 2 - if isinstance(eval_points[0].type, DisconnectedType) and isinstance( - eval_points[1].type, DisconnectedType - ): - return [disconnected_type()] - - if not isinstance(eval_points[0].type, DisconnectedType): - t1 = self(eval_points[0], inputs[1]) - if not isinstance(eval_points[1].type, DisconnectedType): - t2 = self(inputs[0], eval_points[1]) - - if not isinstance(eval_points[0].type, DisconnectedType) and not isinstance( - eval_points[1].type, DisconnectedType - ): - return [t1 + t2] - elif not isinstance(eval_points[0].type, DisconnectedType): - return [t1] - else: - return [t2] - - def infer_shape(self, fgraph, node, shapes): - xshp, yshp = shapes - return [xshp[:-1] + yshp[2:]] - - -_batched_dot = BatchedDot() - - -def batched_dot(a, b): - """Compute the batched dot product of two variables. - - I.e.: - - batched_dot(a, b)[i] = dot(a[i], b[i]) - - Note that this batched_dot function does one of three things, in the - following sequence: - - 1. If either a or b is a vector, it returns the batched elementwise - product without calling the PyTensor BatchedDot op. - - 2. If both a and b have either 2 or 3 dimensions, it calls PyTensor's - BatchedDot op on a and b. - - 3. If either a or b has more than 3 dimensions, it calls PyTensor's - batched_tensordot function with appropriate axes. The - batched_tensordot function expresses high-dimensional batched - dot products in terms of batched matrix-matrix dot products, so - it may be possible to further optimize for performance. - """ - warnings.warn( - "batched_dot is deprecated. " - "Use `dot` in conjunction with `tensor.vectorize` or `graph.replace.vectorize_graph`", - FutureWarning, - ) - a, b = as_tensor_variable(a), as_tensor_variable(b) - - if a.ndim == 0: - raise TypeError("a must have at least one (batch) axis") - elif b.ndim == 0: - raise TypeError("b must have at least one (batch) axis") - - core_a = a[0].type() - core_b = b[0].type() - core_dot = dot(core_a, core_b) - return vectorize_graph(core_dot, replace={core_a: a, core_b: b}) - - -def batched_tensordot(x, y, axes=2): - """Compute a batched tensordot product. - - A hybrid of batched_dot and tensordot, this function computes the - tensordot product between the two tensors, by iterating over the - first dimension to perform a sequence of tensordots. - - Parameters - ---------- - x: TensorVariable - A tensor with sizes e.g.: for 3D (dim1, dim3, dim2) - y: TensorVariable - A tensor with sizes e.g.: for 3D (dim1, dim2, dim4) - axes: int or array-like of length 2 - If an integer, the number of axes to sum over. - If an array, it must have two array elements containing the axes to sum - over in each tensor. - - If an integer i, it is converted to an array containing - the last i dimensions of the first tensor and the first - i dimensions of the second tensor (excluding the first - (batch) dimension): - axes = [list(range(a.ndim - i, b.ndim)), list(range(1,i+1))] - - If an array, its two elements must contain compatible axes - of the two tensors. For example, [[1, 2], [2, 4]] means sum - over the 2nd and 3rd axes of a and the 3rd and 5th axes of b. - (Remember axes are zero-indexed!) The 2nd axis of a and the - 3rd axis of b must have the same shape; the same is true for - the 3rd axis of a and the 5th axis of b. - - Like tensordot, this function uses a series of dimshuffles and - reshapes to reduce the tensor dot product to a matrix or vector - dot product. Finally, it calls batched_dot to compute the result. - """ - warnings.warn( - "batched_tensordot is deprecated. " - "Use `tensordot` in conjuction with `tensor.vectorize` or `graph.replace.vectorize_graph`", - FutureWarning, - ) - - if isinstance(axes, int): - core_axes = axes - else: - # Convert batched axes to core axes - core_axes_a = [a - 1 for a in normalize_axis_tuple(axes[0], x.type.ndim)] - core_axes = [a - 1 for a in normalize_axis_tuple(axes[1], y.type.ndim)] - core_axes = [core_axes_a, core_axes] - - core_x = x[0].type() - core_y = y[0].type() - core_tensordot = tensordot(core_x, core_y, axes=core_axes) - - return vectorize_graph(core_tensordot, replace={core_x: x, core_y: y}) diff --git a/pytensor/tensor/blas/__init__.py b/pytensor/tensor/blas/__init__.py new file mode 100644 index 0000000000..d5e77f51b1 --- /dev/null +++ b/pytensor/tensor/blas/__init__.py @@ -0,0 +1,145 @@ +"""Ops for using BLAS calls + +BLAS = Basic Linear Algebra Subroutines +Learn more about BLAS here: + http://www.netlib.org/blas/blast-forum/ +The standard BLAS libraries implement what is called "legacy BLAS" in that +document. + +This documentation describes PyTensor's BLAS optimization pipeline. + +Where there is a discrepancy between how things do work and how they *should* +work, both aspects should be documented. + +There are four kinds of BLAS Ops in PyTensor: + - Python implementations (this file) + - SciPy-based (blas_scipy) + - C-based (blas_c) + +Notes +----- +Unfortunately (because it's confusing) this file currently contains Ops +that contain both Python and C versions. I think it would be better to +move the C implementations to blas_c so that this file is pure Python. +-JB + + +Ops +=== + +GEMM: Dot22, Dot22Scalar, GemmRelated, Gemm +------------------------------------------- + +The BLAS GEMM operation implements Z <- a X Y + b Z, +where Z, X and Y are matrices, and a and b are scalars. + +Dot22 is a GEMM where a=1, b=0, and Z is allocated every time. + +Dot22Scalar is a GEMM where b=0 and Z is allocated every time. + +Gemm is a GEMM in all its generality. + +In the future we can refactor the GemmRelated, Gemm, Dot22 and +Dot22Scalar Ops into a single Op. That new Op (Gemm2) is basically a +normal Gemm, but with an additional configuration variable that says +to ignore the input Z. Setting that configuration variable to True +would make Gemm2 equivalent to the current Dot22 and Dot22Scalar. +This would make the file a lot easier to read, and save a few hundred +lines of library, to say nothing of testing and documentation. + + +GEMV: Gemv +---------- + +The BLAS GEMV operation implements Z <- a X Y + b Z, +where X is a matrix, Y, and Z are vectors, and a and b are scalars. + + +GER: Ger +-------- + +The BLAS GER operation implements Z <- a X' Y + Z, +where X and Y are vectors, and matrix Z gets a rank-1 update. + + +Other Notable BLAS-related Ops +------------------------------ + +SYRK is another useful special case of GEMM. Particularly SYRK preserves +symmetry in the matrix that it updates. See how the linear-algebra module uses +symmetry hints before implementing this Op, so that this Op is compatible with +that system. + + +Optimizations associated with these BLAS Ops are in tensor.rewriting.blas + +""" + +# Re-export everything for backward compatibility. +# All public symbols that were previously in pytensor.tensor.blas +# must remain importable from this path. + +from pytensor.tensor.blas._core import ( + _ldflags, + _logger, + blas_header_text, + blas_header_version, + ldflags, + must_initialize_y_gemv, + view_roots, +) +from pytensor.tensor.blas.batched import ( + BatchedDot, + _batched_dot, + batched_dot, + batched_tensordot, +) +from pytensor.tensor.blas.gemm import ( + Dot22, + Dot22Scalar, + Gemm, + GemmRelated, + _dot22, + _dot22scalar, + gemm, + gemm_inplace, + gemm_no_inplace, +) +from pytensor.tensor.blas.gemv import Gemv, gemv, gemv_inplace, gemv_no_inplace +from pytensor.tensor.blas.ger import Ger, ger, ger_destructive + + +__all__ = [ + # Core utilities + "view_roots", + "must_initialize_y_gemv", + "ldflags", + "_ldflags", + "_logger", + "blas_header_text", + "blas_header_version", + # Gemv + "Gemv", + "gemv_no_inplace", + "gemv_inplace", + "gemv", + # Ger + "Ger", + "ger", + "ger_destructive", + # GemmRelated / Gemm / Dot22 / Dot22Scalar + "GemmRelated", + "Gemm", + "gemm_inplace", + "gemm_no_inplace", + "gemm", + "Dot22", + "_dot22", + "Dot22Scalar", + "_dot22scalar", + # BatchedDot + "BatchedDot", + "_batched_dot", + "batched_dot", + "batched_tensordot", +] diff --git a/pytensor/tensor/blas/_core.py b/pytensor/tensor/blas/_core.py new file mode 100644 index 0000000000..9b2b76ed79 --- /dev/null +++ b/pytensor/tensor/blas/_core.py @@ -0,0 +1,190 @@ +import functools +import logging +import shlex +from pathlib import Path + +import numpy as np + +from pytensor.configdefaults import config +from pytensor.graph import Variable + + +_logger = logging.getLogger("pytensor.tensor.blas") + + +def view_roots(node: Variable) -> list[Variable]: + """Return the leaves from a search through consecutive view-maps.""" + owner = node.owner + if owner is not None: + try: + vars_to_views = {owner.outputs[o]: i for o, i in owner.op.view_map.items()} + except AttributeError: + return [node] + if node in vars_to_views: + answer = [] + for i in vars_to_views[node]: + answer += view_roots(owner.inputs[i]) + return answer + else: + return [node] + else: + return [node] + + +def must_initialize_y_gemv(): + # Check whether Scipy GEMV could output nan if y in not initialized + from scipy.linalg.blas import get_blas_funcs + + if must_initialize_y_gemv._result is None: + y = np.full((2,), np.nan) + x = np.ones((2,)) + A = np.ones((2, 2)) + gemv = get_blas_funcs("gemv", dtype=y.dtype) + gemv(1.0, A.T, x, 0.0, y, overwrite_y=True, trans=True) + must_initialize_y_gemv._result = np.isnan(y).any() + + return must_initialize_y_gemv._result + + +must_initialize_y_gemv._result = None # type: ignore + + +def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False): + """Extract a list of compilation flags from config.blas__ldflags. + + Depending on the options, different type of flags will be kept. + It returns a list of libraries against which an Op's object file + should be linked to benefit from a BLAS implementation. + + Parameters + ---------- + libs : bool, optional + Extract flags starting with "-l" (the default is True). + libs_dir : bool, optional + Extract flags starting with "-L" (the default is False). + include_dir : bool, optional + Extract flags starting with "-I" (the default is False). + flags: bool, optional + Extract all the other flags (the default is False). + + Returns + ------- + list of strings + Extracted flags. + + """ + ldflags_str = config.blas__ldflags + return _ldflags( + ldflags_str=ldflags_str, + libs=libs, + flags=flags, + libs_dir=libs_dir, + include_dir=include_dir, + ) + + +@functools.cache +def _ldflags( + ldflags_str: str, libs: bool, flags: bool, libs_dir: bool, include_dir: bool +) -> list[str]: + """Extract list of compilation flags from a string. + + Depending on the options, different type of flags will be kept. + + Parameters + ---------- + ldflags_str : string + The string to process. Typically, this will be the content of + `config.blas__ldflags`. + libs : bool + Extract flags starting with "-l". + flags: bool + Extract all the other flags. + libs_dir: bool + Extract flags starting with "-L". + include_dir: bool + Extract flags starting with "-I". + + Returns + ------- + list of strings + Extracted flags. + + """ + rval = [] + if libs_dir: + found_dyn = False + dirs = [x[2:] for x in shlex.split(ldflags_str) if x.startswith("-L")] + l = _ldflags( + ldflags_str=ldflags_str, + libs=True, + flags=False, + libs_dir=False, + include_dir=False, + ) + for d in dirs: + for f in Path(d.strip('"')).iterdir(): + if f.suffix in {".so", ".dylib", ".dll"}: + if any(f.stem.find(ll) >= 0 for ll in l): + found_dyn = True + # Special treatment of clang framework. Specifically for MacOS Accelerate + if "-framework" in l and "Accelerate" in l: + found_dyn = True + if not found_dyn and dirs: + _logger.warning( + "We did not find a dynamic library in the " + "library_dir of the library we use for blas. If you use " + "ATLAS, make sure to compile it with dynamics library." + ) + + split_flags = shlex.split(ldflags_str) + skip = False + for pos, t in enumerate(split_flags): + if skip: + skip = False + continue + # Remove extra quote. + if (t.startswith("'") and t.endswith("'")) or ( + t.startswith('"') and t.endswith('"') + ): + t = t[1:-1] + + try: + t0, t1 = t[0], t[1] + assert t0 == "-" or Path(t).exists() + except Exception: + raise ValueError(f'invalid token "{t}" in ldflags_str: "{ldflags_str}"') + if t == "-framework": + skip = True + # Special treatment of clang framework. Specifically for MacOS Accelerate + # The clang framework implicitly adds: header dirs, libraries, and library dirs. + # If we choose to always return these flags, we run into a huge deal amount of + # incompatibilities. For this reason, we only return the framework if libs are + # requested. + if ( + libs + and len(split_flags) >= pos + and split_flags[pos + 1] == "Accelerate" + ): + # We only add the Accelerate framework, but in the future we could extend it to + # other frameworks + rval.append(t) + rval.append(split_flags[pos + 1]) + elif libs_dir and t1 == "L": + rval.append(t[2:]) + elif include_dir and t1 == "I": + raise ValueError( + "Include dirs are not used for blas. We disable" + " this as this can hide other headers and this" + " is not wanted.", + t, + ) + elif libs and t1 == "l": # example -lmkl + rval.append(t[2:]) + elif flags and t1 not in ("L", "I", "l"): # example -openmp + rval.append(t) + elif flags and t1 == "L": + # to find it when we load the compiled op if the env of the + # used is not well configured. + rval.append("-Wl,-rpath," + t[2:]) + return rval diff --git a/pytensor/tensor/blas/batched.py b/pytensor/tensor/blas/batched.py new file mode 100644 index 0000000000..59a4036555 --- /dev/null +++ b/pytensor/tensor/blas/batched.py @@ -0,0 +1,460 @@ +"""BatchedDot Op and user-facing batched_dot/batched_tensordot functions. + +BatchedDot computes: batched_dot(a, b)[i] = dot(a[i], b[i]) +""" + +import warnings + +import numpy as np +from numpy.lib.array_utils import normalize_axis_tuple + +import pytensor.scalar +from pytensor.gradient import DisconnectedType, disconnected_type +from pytensor.graph import vectorize_graph +from pytensor.graph.basic import Apply +from pytensor.link.c.op import COp +from pytensor.tensor.basic import as_tensor_variable, cast +from pytensor.tensor.blas._core import blas_header_text, ldflags +from pytensor.tensor.blas_headers import blas_header_version +from pytensor.tensor.math import dot, tensordot +from pytensor.tensor.shape import specify_broadcastable +from pytensor.tensor.type import DenseTensorType, tensor + + +class BatchedDot(COp): + """ + Computes a batch matrix-matrix dot with tensor3 variables + + batched_dot(a, b)[i] = dot(a[i], b[i]) + """ + + __props__ = () + gufunc_signature = "(b,m,k),(b,k,n)->(b,m,n)" + + def make_node(self, x, y): + x = as_tensor_variable(x) + y = as_tensor_variable(y) + + if not ( + isinstance(x.type, DenseTensorType) and isinstance(y.type, DenseTensorType) + ): + raise NotImplementedError("Only dense tensor types are supported") + + if not (x.type.ndim == 3 and y.type.ndim == 3): + raise TypeError( + f"Inputs must have 3 ndim, but got {x.type.ndim} and {y.type.ndim}. " + "Consider calling batched_dot instead." + ) + + def extract_static_dim(dim_x, dim_y): + dims = {dim_x, dim_y} - {None} + if len(dims) > 1: + # BatchedDot doesn't allow broadcasting + raise ValueError( + f"Static dimensions of BatchedDot don't match, got {x.type.shape} and {y.type.shape}" + ) + elif not dims: + return None + else: + return dims.pop() + + x_batch_dim, x_row_dim, x_sum_dim = x.type.shape + y_batch_dim, y_sum_dim, y_col_dim = y.type.shape + batch_dim = extract_static_dim(x_batch_dim, y_batch_dim) + # Raise if static sum dimensions do not match + _ = extract_static_dim(x_sum_dim, y_sum_dim) + out_shape = (batch_dim, x_row_dim, y_col_dim) + + # Change dtype if needed + dtype = pytensor.scalar.upcast(x.type.dtype, y.type.dtype) + x, y = cast(x, dtype), cast(y, dtype) + out = tensor(dtype=dtype, shape=out_shape) + return Apply(self, [x, y], [out]) + + def perform(self, node, inp, out): + x, y = inp + (z,) = out + + if x.shape[0] != y.shape[0]: + raise TypeError( + f"Inputs [{', '.join(map(str, inp))}] must have the" + f" same size in axis 0, but have sizes [{', '.join(str(i.shape[0]) for i in inp)}]." + ) + + z[0] = np.matmul(x, y) + + def c_support_code(self, **kwargs): + batch_gemm_defn = """ + template + bool batch_gemm(void (*gemm)(char*, char*, const int*, const int*, const int*, const dtype*, const dtype*, const int*, const dtype*, const int*, const dtype*, dtype*, const int*), + int type_size, PyArrayObject* xs, PyArrayObject* ys, + PyArrayObject* zs) { + npy_intp *Nx = PyArray_DIMS(xs), *Sx = PyArray_STRIDES(xs); + npy_intp *Ny = PyArray_DIMS(ys), *Sy = PyArray_STRIDES(ys); + npy_intp *Nz = PyArray_DIMS(zs), *Sz = PyArray_STRIDES(zs); + + if (Nx[0] != Ny[0]) { + PyErr_Format(PyExc_ValueError, + "Shape mismatch: batch sizes unequal." + " x.shape is (%d, %d, %d)," + " y.shape is (%d, %d, %d).", + Nx[0], Nx[1], Nx[2], + Ny[0], Ny[1], Ny[2]); + return 1; + } + + if (Nx[2] != Ny[1]) { + PyErr_Format(PyExc_ValueError, + "Shape mismatch: summation axis sizes unequal." + " x.shape is (%d, %d, %d)," + " y.shape is (%d, %d, %d).", + Nx[0], Nx[1], Nx[2], + Ny[0], Ny[1], Ny[2]); + return 1; + } + + /* encode the stride structure of _x,_y,_z into a single integer. */ + int unit = 0; + unit |= ((Sx[2] == type_size || Nx[2] == 1) ? 0x0 : (Sx[1] == type_size || Nx[1]==1) ? 0x1 : 0x2) << 8; + unit |= ((Sy[2] == type_size || Ny[2] == 1) ? 0x0 : (Sy[1] == type_size || Ny[1]==1) ? 0x1 : 0x2) << 4; + unit |= ((Sz[2] == type_size || Nz[2] == 1) ? 0x0 : (Sz[1] == type_size || Nz[1]==1) ? 0x1 : 0x2) << 0; + + /* create appropriate strides for malformed matrices that are row or column + * vectors, or empty matrices. + * In that case, the value of the stride does not really matter, but + * some versions of BLAS insist that: + * - they are not smaller than the number of elements in the array, + * - they are not 0. + */ + int sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : (Nx[2] + 1); + int sx_2 = (Nx[2] > 1) ? Sx[2]/type_size : (Nx[1] + 1); + int sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : (Ny[2] + 1); + int sy_2 = (Ny[2] > 1) ? Sy[2]/type_size : (Ny[1] + 1); + int sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : (Nz[2] + 1); + int sz_2 = (Nz[2] > 1) ? Sz[2]/type_size : (Nz[1] + 1); + + dtype* x = (dtype*)PyArray_DATA(xs); + dtype* y = (dtype*)PyArray_DATA(ys); + dtype* z = (dtype*)PyArray_DATA(zs); + + dtype a = 1.0; + dtype b = 0.0; + char N = 'N'; + char T = 'T'; + int Nz1 = Nz[1], Nz2 = Nz[2], Nx2 = Nx[2]; + + // loop over batch axis + for (int i = 0; i < Nz[0]; i++) { + switch(unit) + { + case 0x000: gemm(&N, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_1, &b, z, &sz_1); break; + case 0x100: gemm(&N, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_1, x, &sx_2, &b, z, &sz_1); break; + case 0x010: gemm(&T, &N, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_1, &b, z, &sz_1); break; + case 0x110: gemm(&T, &T, &Nz2, &Nz1, &Nx2, &a, y, &sy_2, x, &sx_2, &b, z, &sz_1); break; + case 0x001: gemm(&T, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_1, &b, z, &sz_2); break; + case 0x101: gemm(&N, &T, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_1, &b, z, &sz_2); break; + case 0x011: gemm(&T, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_1, y, &sy_2, &b, z, &sz_2); break; + case 0x111: gemm(&N, &N, &Nz1, &Nz2, &Nx2, &a, x, &sx_2, y, &sy_2, &b, z, &sz_2); break; + default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); return 1; + }; + x += Sx[0] / type_size; + y += Sy[0] / type_size; + z += Sz[0] / type_size; + } + + return 0; + } + """ + return blas_header_text() + batch_gemm_defn + + def c_libraries(self, **kwargs): + return ldflags() + + def c_compile_args(self, **kwargs): + return ldflags(libs=False, flags=True) + + def c_lib_dirs(self, **kwargs): + return ldflags(libs=False, libs_dir=True) + + def c_header_dirs(self, **kwargs): + return ldflags(libs=False, include_dir=True) + + def c_code(self, node, name, inp, out, sub): + # Can only compile if linked to blas libraries + if len(self.c_libraries()) <= 0: + raise NotImplementedError() + + _x, _y = inp + (_z,) = out + fail = sub["fail"] + + # generate contiguity condition + def contiguous(var, ndim): + strides = f"PyArray_STRIDES({var})" + if ndim == 1: + return f"{strides}[0] == type_size" + ands = " && ".join( + f"{strides}[{i}] > 0 && {strides}[{i}] % type_size == 0" + for i in range(1, ndim) + ) + ors = " || ".join(f"{strides}[{i}] == type_size" for i in range(1, ndim)) + return f"{ands} && ({ors})" + + x_ndim, y_ndim, z_ndim = ( + node.inputs[0].ndim, + node.inputs[1].ndim, + node.outputs[0].ndim, + ) + + # generate code to allocate output based on runtime input shapes + z_dims = [ + f"PyArray_DIMS({_x})[0]", + f"PyArray_DIMS({_x})[1]", + f"PyArray_DIMS({_y})[2]", + ] + + z_shape_correct = " && ".join( + f"PyArray_DIMS({_z})[{i}] == {dim}" for i, dim in enumerate(z_dims) + ) + z_shape = ", ".join(z_dims) + z_contiguous = contiguous(_z, z_ndim) + allocate = f""" + if (NULL == {_z} || !({z_shape_correct}) || !({z_contiguous})) + {{ + npy_intp dims[{z_ndim}] = {{{z_shape}}}; + Py_XDECREF({_z}); + {_z} = (PyArrayObject*)PyArray_SimpleNew( + {z_ndim}, dims, PyArray_TYPE({_x})); + if(!{_z}) {{ + PyErr_SetString(PyExc_MemoryError, + "failed to alloc BatchedDot output"); + {fail} + }} + }} + """ + + # code to reallocate inputs contiguously if necessary + contiguate = [] + for var, ndim in [(_x, x_ndim), (_y, y_ndim)]: + _contiguous = contiguous(var, ndim) + contiguate.append( + f""" + if (!({_contiguous})) {{ + PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy({var}); + if (!_copy) + {fail} + Py_XDECREF({var}); + {var} = _copy; + }} + """ + ) + contiguate = "\n".join(contiguate) + + return f""" + int type_num = PyArray_DESCR({_x})->type_num; + int type_size = PyArray_ITEMSIZE({_x}); // in bytes + + if (PyArray_NDIM({_x}) != 3) {{ + PyErr_Format(PyExc_NotImplementedError, + "rank(x) != 3. rank(x) is %d.", + PyArray_NDIM({_x})); + {fail}; + }} + if (PyArray_NDIM({_y}) != 3) {{ + PyErr_Format(PyExc_NotImplementedError, + "rank(y) != 3. rank(y) is %d.", + PyArray_NDIM({_y})); + {fail}; + }} + if ({_z} && PyArray_NDIM({_z}) != 3) {{ + PyErr_Format(PyExc_NotImplementedError, + "rank(z) != 3. rank(z) is %d.", + PyArray_NDIM({_z})); + {fail}; + }} + + // allocate output + {allocate} + // reallocate any noncontiguous arrays or arrays with invalid strides + {contiguate} + + if ((PyArray_DESCR({_x})->type_num != NPY_DOUBLE) + && (PyArray_DESCR({_x})->type_num != NPY_FLOAT)) + {{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); {fail};}} + + if ((PyArray_DESCR({_y})->type_num != NPY_DOUBLE) + && (PyArray_DESCR({_y})->type_num != NPY_FLOAT)) + {{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); {fail};}} + + if ((PyArray_DESCR({_z})->type_num != NPY_DOUBLE) + && (PyArray_DESCR({_z})->type_num != NPY_FLOAT)) + {{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); {fail};}} + + if ((PyArray_DESCR({_x})->type_num != PyArray_DESCR({_y})->type_num) + ||(PyArray_DESCR({_x})->type_num != PyArray_DESCR({_z})->type_num)) + {{ PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); {fail}; }} + + switch (type_num) + {{ + case NPY_FLOAT: + if (batch_gemm(sgemm_, type_size, {_x}, {_y}, {_z})) {{ + {fail}; + }} + break; + case NPY_DOUBLE: + if (batch_gemm(dgemm_, type_size, {_x}, {_y}, {_z})) {{ + {fail}; + }} + break; + }} + """ + + def c_code_cache_version(self): + return (6, blas_header_version()) + + def pullback(self, inp, outputs, grads): + x, y = inp + (gz,) = grads + + xgrad = _batched_dot(gz, y.dimshuffle(0, 2, 1)) + ygrad = _batched_dot(x.dimshuffle(0, 2, 1), gz) + + # If x or y contain broadcastable dimensions but only one of + # them know that a matching dimensions is broadcastable, the + # above code don't always return the right broadcast pattern. + # This cause problem down the road. See gh-1461. + if xgrad.broadcastable != x.broadcastable: + xgrad = specify_broadcastable( + xgrad, *(ax for (ax, b) in enumerate(x.type.broadcastable) if b) + ) + if ygrad.broadcastable != y.broadcastable: + ygrad = specify_broadcastable( + ygrad, *(ax for (ax, b) in enumerate(y.type.broadcastable) if b) + ) + + return xgrad, ygrad + + def pushforward(self, inputs, outputs, eval_points): + assert len(inputs) == 2 + assert len(eval_points) == 2 + if isinstance(eval_points[0].type, DisconnectedType) and isinstance( + eval_points[1].type, DisconnectedType + ): + return [disconnected_type()] + + if not isinstance(eval_points[0].type, DisconnectedType): + t1 = self(eval_points[0], inputs[1]) + if not isinstance(eval_points[1].type, DisconnectedType): + t2 = self(inputs[0], eval_points[1]) + + if not isinstance(eval_points[0].type, DisconnectedType) and not isinstance( + eval_points[1].type, DisconnectedType + ): + return [t1 + t2] + elif not isinstance(eval_points[0].type, DisconnectedType): + return [t1] + else: + return [t2] + + def infer_shape(self, fgraph, node, shapes): + xshp, yshp = shapes + return [xshp[:-1] + yshp[2:]] + + +_batched_dot = BatchedDot() + + +def batched_dot(a, b): + """Compute the batched dot product of two variables. + + I.e.: + + batched_dot(a, b)[i] = dot(a[i], b[i]) + + Note that this batched_dot function does one of three things, in the + following sequence: + + 1. If either a or b is a vector, it returns the batched elementwise + product without calling the PyTensor BatchedDot op. + + 2. If both a and b have either 2 or 3 dimensions, it calls PyTensor's + BatchedDot op on a and b. + + 3. If either a or b has more than 3 dimensions, it calls PyTensor's + batched_tensordot function with appropriate axes. The + batched_tensordot function expresses high-dimensional batched + dot products in terms of batched matrix-matrix dot products, so + it may be possible to further optimize for performance. + """ + warnings.warn( + "batched_dot is deprecated. " + "Use `dot` in conjunction with `tensor.vectorize` or `graph.replace.vectorize_graph`", + FutureWarning, + ) + a, b = as_tensor_variable(a), as_tensor_variable(b) + + if a.ndim == 0: + raise TypeError("a must have at least one (batch) axis") + elif b.ndim == 0: + raise TypeError("b must have at least one (batch) axis") + + core_a = a[0].type() + core_b = b[0].type() + core_dot = dot(core_a, core_b) + return vectorize_graph(core_dot, replace={core_a: a, core_b: b}) + + +def batched_tensordot(x, y, axes=2): + """Compute a batched tensordot product. + + A hybrid of batched_dot and tensordot, this function computes the + tensordot product between the two tensors, by iterating over the + first dimension to perform a sequence of tensordots. + + Parameters + ---------- + x: TensorVariable + A tensor with sizes e.g.: for 3D (dim1, dim3, dim2) + y: TensorVariable + A tensor with sizes e.g.: for 3D (dim1, dim2, dim4) + axes: int or array-like of length 2 + If an integer, the number of axes to sum over. + If an array, it must have two array elements containing the axes to sum + over in each tensor. + + If an integer i, it is converted to an array containing + the last i dimensions of the first tensor and the first + i dimensions of the second tensor (excluding the first + (batch) dimension): + axes = [list(range(a.ndim - i, b.ndim)), list(range(1,i+1))] + + If an array, its two elements must contain compatible axes + of the two tensors. For example, [[1, 2], [2, 4]] means sum + over the 2nd and 3rd axes of a and the 3rd and 5th axes of b. + (Remember axes are zero-indexed!) The 2nd axis of a and the + 3rd axis of b must have the same shape; the same is true for + the 3rd axis of a and the 5th axis of b. + + Like tensordot, this function uses a series of dimshuffles and + reshapes to reduce the tensor dot product to a matrix or vector + dot product. Finally, it calls batched_dot to compute the result. + """ + warnings.warn( + "batched_tensordot is deprecated. " + "Use `tensordot` in conjuction with `tensor.vectorize` or `graph.replace.vectorize_graph`", + FutureWarning, + ) + + if isinstance(axes, int): + core_axes = axes + else: + # Convert batched axes to core axes + core_axes_a = [a - 1 for a in normalize_axis_tuple(axes[0], x.type.ndim)] + core_axes = [a - 1 for a in normalize_axis_tuple(axes[1], y.type.ndim)] + core_axes = [core_axes_a, core_axes] + + core_x = x[0].type() + core_y = y[0].type() + core_tensordot = tensordot(core_x, core_y, axes=core_axes) + + return vectorize_graph(core_tensordot, replace={core_x: x, core_y: y}) diff --git a/pytensor/tensor/blas/gemm.py b/pytensor/tensor/blas/gemm.py new file mode 100644 index 0000000000..3f2aec1fd1 --- /dev/null +++ b/pytensor/tensor/blas/gemm.py @@ -0,0 +1,860 @@ +"""BLAS GEMM family: Gemm, Dot22, Dot22Scalar, and the GemmRelated base class. + +Gemm computes: b*z + a*dot(x,y) +Dot22 computes: dot(x, y) (matrix-matrix, BLAS-accelerated) +Dot22Scalar computes: scalar * dot(x, y) +""" + +import numpy as np + +import pytensor.scalar +from pytensor.graph.basic import Apply +from pytensor.graph.utils import InconsistencyError, MethodNotDefined +from pytensor.link.c.op import COp +from pytensor.link.c.params_type import ParamsType +from pytensor.printing import FunctionPrinter, pprint +from pytensor.scalar import bool as bool_t +from pytensor.tensor.basic import as_tensor_variable +from pytensor.tensor.blas._core import ( + blas_header_text, + blas_header_version, + ldflags, + view_roots, +) +from pytensor.tensor.type import DenseTensorType, tensor + + +class GemmRelated(COp): + """Base class for Gemm and Dot22. + + This class provides a kind of templated gemm Op. + + """ + + __props__: tuple[str, ...] = () + + def c_support_code(self, **kwargs): + # return cblas_header_text() + mod_str = """ + #ifndef MOD + #define MOD % + #endif + void compute_strides(npy_intp *shape, int N_shape, int type_size, npy_intp *res) { + int s; + res[N_shape - 1] = type_size; + for (int i = N_shape - 1; i > 0; i--) { + s = shape[i]; + res[i - 1] = res[i] * (s > 0 ? s : 1); + } + } + """ + return blas_header_text() + mod_str + + def c_headers(self, **kwargs): + return [] + + def c_libraries(self, **kwargs): + return ldflags() + + # code_cache_version is built by subclasses from + # build_gemm_version + + def c_compile_args(self, **kwargs): + return ldflags(libs=False, flags=True) + + def c_lib_dirs(self, **kwargs): + return ldflags(libs=False, libs_dir=True) + + def c_header_dirs(self, **kwargs): + return ldflags(libs=False, include_dir=True) + + declare_NS = """ + int unit = 0; + + int type_num = PyArray_DESCR(%(_x)s)->type_num; + int type_size = PyArray_ITEMSIZE(%(_x)s); // in bytes + + npy_intp* Nx = PyArray_DIMS(%(_x)s); + npy_intp* Ny = PyArray_DIMS(%(_y)s); + npy_intp* Nz = 0; //PyArray_DIMS(%(_zout)s); + + npy_intp* Sx = PyArray_STRIDES(%(_x)s); + npy_intp* Sy = PyArray_STRIDES(%(_y)s); + npy_intp* Sz = 0; //PyArray_STRIDES(%(_zout)s); + + //strides for x, y, z in dimensions 0, 1 + int sx_0, sx_1, sy_0, sy_1, sz_0, sz_1; + """ + + # implement if you don't have an inplace props + # setup_z_Nz_Sz = None + # otherwise implement + # setup_z_Nz_Sz_inplace = None + # setup_z_Nz_Sz_outplace = None + + check_xyz_rank2 = """ + if (PyArray_NDIM(%(_x)s) != 2) { + PyErr_Format(PyExc_NotImplementedError, + "rank(x) != 2. rank(x) is %%d.", + PyArray_NDIM(%(_x)s)); + %(fail)s; + } + if (PyArray_NDIM(%(_y)s) != 2) { + PyErr_Format(PyExc_NotImplementedError, + "rank(y) != 2. rank(y) is %%d.", PyArray_NDIM(%(_y)s)); + %(fail)s; + } + if (%(_zout)s && PyArray_NDIM(%(_zout)s) != 2) { + PyErr_Format(PyExc_NotImplementedError, + "rank(z) != 2. rank(z) is %%d.", PyArray_NDIM(%(_zout)s)); + %(fail)s; + } + """ + check_xyz_double_or_float = """ + if ((PyArray_DESCR(%(_x)s)->type_num != NPY_DOUBLE) + && (PyArray_DESCR(%(_x)s)->type_num != NPY_FLOAT)) + {PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;} + + if ((PyArray_DESCR(%(_y)s)->type_num != NPY_DOUBLE) + && (PyArray_DESCR(%(_y)s)->type_num != NPY_FLOAT)) + {PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;} + + if ((PyArray_DESCR(%(_zout)s)->type_num != NPY_DOUBLE) + && (PyArray_DESCR(%(_zout)s)->type_num != NPY_FLOAT)) + {PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;} + + if ((PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_y)s)->type_num) + ||(PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_zout)s)->type_num)) + { PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); %(fail)s; } + """ + + # it is not necessary that a or b have the same type as x,y,z + check_ab_double_or_float = """ + if ((PyArray_DESCR(%(_a)s)->type_num != NPY_DOUBLE) + && (PyArray_DESCR(%(_a)s)->type_num != NPY_FLOAT)) + {PyErr_SetString(PyExc_NotImplementedError, "type(a) is not double or float"); %(fail)s;} + + if ((PyArray_DESCR(%(_b)s)->type_num != NPY_DOUBLE) + && (PyArray_DESCR(%(_b)s)->type_num != NPY_FLOAT)) + {PyErr_SetString(PyExc_NotImplementedError, "type(b) is not double or float"); %(fail)s;} + """ + + # broadcast_xy = None + + check_dims = """ + if (Nx[0] !=1 && Nz[0] != 1 && Nx[0] != Nz[0]) + { + PyErr_Format(PyExc_ValueError, + "Shape mismatch: x has %%ld rows but z has %%ld rows", + (long int)Nx[0], (long int)Nz[0]); + %(fail)s; + } + if (Nx[1] != Ny[0]) + { + PyErr_Format(PyExc_ValueError, + "Shape mismatch: x has %%ld cols (and %%ld rows) but y has %%ld rows (and %%ld cols)", + (long int)Nx[1], (long int)Nx[0], (long int)Ny[0], (long int)Ny[1]); + %(fail)s; + } + if (Ny[1] != 1 && Nz[1]!= 1 && Ny[1] != Nz[1]) + { + PyErr_Format(PyExc_ValueError, + "Shape mismatch: y has %%ld cols but z has %%ld cols", + (long int)Ny[1], (long int)Nz[1]); + %(fail)s; + } + + // We must not raise an error when Nx[1] == 0. This would disable cases + // that numpy.dot accept. + """ + + check_strides = """ + /* + If some matrices are not contiguous on either dimensions, + or have invalid strides, copy their content into a contiguous one + */ + if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] MOD type_size) || (Sx[1] MOD type_size) + || ((Sx[0] != type_size) && (Sx[1] != type_size))) + { + PyArrayObject * _x_copy = (PyArrayObject *) PyArray_Copy(%(_x)s); + if (!_x_copy) + %(fail)s + Py_XDECREF(%(_x)s); + %(_x)s = _x_copy; + Sx = PyArray_STRIDES(%(_x)s); + if ((Sx[0] < 1) || (Sx[1] < 1)) { + compute_strides(Nx, 2, type_size, Sx); + } + } + + if ((Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size) + || ((Sy[0] != type_size) && (Sy[1] != type_size))) + { + PyArrayObject * _y_copy = (PyArrayObject *) PyArray_Copy(%(_y)s); + if (!_y_copy) + %(fail)s + Py_XDECREF(%(_y)s); + %(_y)s = _y_copy; + Sy = PyArray_STRIDES(%(_y)s); + if ((Sy[0] < 1) || (Sy[1] < 1)) { + compute_strides(Ny, 2, type_size, Sy); + } + } + + if ((Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size) + || ((Sz[0] != type_size) && (Sz[1] != type_size))) + { + PyArrayObject * _z_copy = (PyArrayObject *) PyArray_Copy(%(_zout)s); + if (!_z_copy) + %(fail)s + Py_XDECREF(%(_zout)s); + %(_zout)s = _z_copy; + Sz = PyArray_STRIDES(%(_zout)s); + if ((Sz[0] < 1) || (Sz[1] < 1)) { + compute_strides(Nz, 2, type_size, Sz); + } + } + """ + + encode_strides_in_unit = """ + /* + encode the stride structure of _x,_y,_zout into a single integer + */ + unit |= ((Sx[1] == type_size || Nx[1]==1) ? 0x0 : (Sx[0] == type_size || Nx[0]==1) ? 0x1 : 0x2) << 8; + unit |= ((Sy[1] == type_size || Ny[1]==1) ? 0x0 : (Sy[0] == type_size || Ny[0]==1) ? 0x1 : 0x2) << 4; + unit |= ((Sz[1] == type_size || Nz[1]==1) ? 0x0 : (Sz[0] == type_size || Nz[0]==1) ? 0x1 : 0x2) << 0; + """ + + compute_strides = """ + /* create appropriate strides for malformed matrices that are row or column + * vectors, or empty matrices. + * In that case, the value of the stride does not really matter, but + * some versions of BLAS insist that: + * - they are not smaller than the number of elements in the array, + * - they are not 0. + */ + sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : (Nx[1] + 1); + sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : (Nx[0] + 1); + sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : (Ny[1] + 1); + sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : (Ny[0] + 1); + sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : (Nz[1] + 1); + sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : (Nz[0] + 1); + """ + + begin_switch_typenum = """ + switch (type_num) + { + """ + + case_float = """ + case NPY_FLOAT: + { + """ + + # case_float_ab_constants = None + + case_float_gemm = """ + float* x = (float*)PyArray_DATA(%(_x)s); + float* y = (float*)PyArray_DATA(%(_y)s); + float* z = (float*)PyArray_DATA(%(_zout)s); + char N = 'N'; + char T = 'T'; + int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1]; + switch(unit) + { + case 0x000: sgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break; + case 0x100: sgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break; + case 0x010: sgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break; + case 0x110: sgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break; + case 0x001: sgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break; + case 0x101: sgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break; + case 0x011: sgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break; + case 0x111: sgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break; + default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s; + }; + """ + + case_double = """ + } + break; + case NPY_DOUBLE: + { + """ + + # case_double_ab_constants = None + + case_double_gemm = """ + double* x = (double*)PyArray_DATA(%(_x)s); + double* y = (double*)PyArray_DATA(%(_y)s); + double* z = (double*)PyArray_DATA(%(_zout)s); + char N = 'N'; + char T = 'T'; + int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1]; + switch(unit) + { + case 0x000: dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, + &sy_0, x, &sx_0, &b, z, &sz_0); break; + case 0x100: dgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, + &sy_0, x, &sx_1, &b, z, &sz_0); break; + case 0x010: dgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, + &sy_1, x, &sx_0, &b, z, &sz_0); break; + case 0x110: dgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, + &sy_1, x, &sx_1, &b, z, &sz_0); break; + case 0x001: dgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, + &sx_0, y, &sy_0, &b, z, &sz_1); break; + case 0x101: dgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, + &sx_1, y, &sy_0, &b, z, &sz_1); break; + case 0x011: dgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, + &sx_0, y, &sy_1, &b, z, &sz_1); break; + case 0x111: dgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, + &sx_1, y, &sy_1, &b, z, &sz_1); break; + default: PyErr_SetString(PyExc_ValueError, + "some matrix has no unit stride"); + %(fail)s; + }; + """ + + end_switch_typenum = """ + } + break; + } + """ + + def build_gemm_call(self): + if hasattr(self, "inplace"): + setup_z_Nz_Sz = f"if(%(params)s->inplace){{{self.setup_z_Nz_Sz_inplace}}}else{{{self.setup_z_Nz_Sz_outplace}}}" + else: + setup_z_Nz_Sz = self.setup_z_Nz_Sz + + return "".join( + ( + self.declare_NS, + self.check_xyz_rank2, + setup_z_Nz_Sz, + self.check_xyz_double_or_float, + self.check_ab_double_or_float, + self.broadcast_xy, + self.check_dims, + self.check_strides, + self.encode_strides_in_unit, + self.compute_strides, + self.begin_switch_typenum, + self.case_float, + self.case_float_ab_constants, + self.case_float_gemm, + self.case_double, + self.case_double_ab_constants, + self.case_double_gemm, + self.end_switch_typenum, + ) + ) + + def build_gemm_version(self): + return (14, blas_header_version()) + + +class Gemm(GemmRelated): + """In-place version of matrix-matrix multiplication (with accumulation). + + When a and b are scalars and x, y, and z are matrices, then + + gemm(z,a,x,y,b) + + is similar to + + b*z + a*dot(x,y) + + The difference between the two is that the top form is destructive + on z, whereas the bottom form is not. Gemm works in-place on the + storage associated with z, and the L{Variable} returned by Gemm + has a storage that will be aliased to the storage of the z + argument. Because of this in-place computation, an L{Apply} of + this op will destroy the L{Variable} z on which it operates. (See + L{DestructiveOps} for an explanation of what destroying means in + the context of pytensor graphs. See L{BlasLapackSupport} for more + optimized linear algebra operations.) + + """ + + E_rank = "gemm only works for rank 2" + E_scalar = "gemm requires scalar argument" + E_z_uniq = "argument z aliased to x or y" # TODO: justify / delete this + E_mixed = "gemm requires matching dtypes" + E_float = "gemm requires floating-point dtypes" + + __props__ = ("inplace",) + params_type = ParamsType( + inplace=bool_t, + ) + check_input = False + + def __init__(self, inplace): + self.inplace = inplace + if self.inplace: + self.destroy_map = {0: [0]} + + def __str__(self): + if self.inplace: + inplace_str = "inplace" + else: + inplace_str = "no_inplace" + return f"{self.__class__.__name__}{{{inplace_str}}}" + + def __setstate__(self, dct): + self.__dict__.update(dct) + + # Correctly reload older pickles where destroy_map were not + # saved + if "destroy_map" not in self.__dict__ and self.inplace: + self.destroy_map = {0: [0]} + + def __getstate__(self): + rval = self.__dict__.copy() + # Do not serialize the setup code, it will be restored in __setstate__ + # depending on the value of 'inplace' + rval.pop("setup_z_Nz_Sz", None) + return rval + + def make_node(self, *inputs): + inputs = list(map(as_tensor_variable, inputs)) + + if any(not isinstance(i.type, DenseTensorType) for i in inputs): + raise NotImplementedError("Only dense tensor types are supported") + + if len(inputs) != 5: + raise TypeError( + f"Wrong number of inputs for {self} (expected 5, got {len(inputs)})" + ) + z, a, x, y, b = inputs + + zr, xr, yr = (set(view_roots(i)) for i in (z, x, y)) + + # We want the gemm to be inplace. When this op is inplace, it + # declare to be inplace only on z. So to make it safe, we + # raise an error if z can be a view on x or y. + + # I don't know if PyTensor currently can support that case. As + # this case don't happen in our code, I won't spent time + # investigating this. So the assert is for safety. I also + # think there is another mechanism that would prevent this, + # but I don't what to modify old code and have chance to break + # something. + if self.inplace: + if zr.intersection(xr): + raise InconsistencyError(Gemm.E_z_uniq, (z, x)) + if zr.intersection(yr): + raise InconsistencyError(Gemm.E_z_uniq, (z, y)) + + if z.ndim != 2: + raise TypeError(Gemm.E_rank, z) + if a.ndim != 0: + raise TypeError(Gemm.E_scalar, a) + if x.ndim != 2: + raise TypeError(Gemm.E_rank, x) + if y.ndim != 2: + raise TypeError(Gemm.E_rank, y) + if b.ndim != 0: + raise TypeError(Gemm.E_scalar, b) + + if not (z.dtype == a.dtype == x.dtype == y.dtype == b.dtype): + raise TypeError(Gemm.E_mixed, (z.dtype, a.dtype, x.dtype, y.dtype, b.dtype)) + + if not z.dtype.startswith("float") and not z.dtype.startswith("complex"): + raise TypeError(Gemm.E_float, (z.dtype)) + + output = z.type() + return Apply(self, inputs, [output]) + + def perform(self, node, inp, out): + z, a, x, y, b = inp + (zout,) = out + assert a.shape == () + assert b.shape == () + if not self.inplace: + z = z.copy() # the original z will not be changed + if z.shape == (): + z.itemset(z * a + b * np.dot(x, y)) + zout[0] = z + else: + # Broadcast Z if needed + if (x.shape[0] > z.shape[0]) or (y.shape[1] > z.shape[1]): + z = np.broadcast_to( + z, (max(x.shape[0], z.shape[0]), max(y.shape[1], z.shape[1])) + ).copy() + if b == 0.0: + if a == 1.0: + z[:] = np.dot(x, y) + elif a == -1.0: + z[:] = -np.dot(x, y) + else: + z[:] = a * np.dot(x, y) + elif b == 1.0: + if a == 1.0: + z += np.dot(x, y) + elif a == -1.0: + z -= np.dot(x, y) + else: + z += a * np.dot(x, y) + else: + z *= b + z += a * np.dot(x, y) + zout[0] = z + + def infer_shape(self, fgraph, node, input_shapes): + z_shape, _, x_shape, y_shape, _ = input_shapes + return [ + ( + pytensor.scalar.maximum(z_shape[0], x_shape[0]), + pytensor.scalar.maximum(z_shape[1], y_shape[1]), + ) + ] + + setup_z_Nz_Sz_inplace = """ + // Needs broadcasting + if (PyArray_DIMS(%(_z)s)[0] < Nx[0] || PyArray_DIMS(%(_z)s)[1] < Ny[1]){ + + npy_intp dims[2]; + dims[0] = (PyArray_DIMS(%(_z)s)[0] >= Nx[0]) ? PyArray_DIMS(%(_z)s)[0] : Nx[0]; + dims[1] = (PyArray_DIMS(%(_z)s)[1] >= Ny[1]) ? PyArray_DIMS(%(_z)s)[1] : Ny[1]; + + // Check if we need to allocate new array + if((NULL == %(_zout)s) + || (PyArray_DIMS(%(_zout)s)[0] != dims[0]) + || (PyArray_DIMS(%(_zout)s)[1] != dims[1])) + { + // fprintf(stderr, "Gemm Allocating z output array with shape (%%i %%i)\\n", dims[0], dims[1]); + Py_XDECREF(%(_zout)s); + %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_z)s)); + } + + // fprintf(stderr, "Gemm Broadcasting Z into shape (%%i %%i)\\n", dims[0], dims[1]); + if(PyArray_CopyInto(%(_zout)s, %(_z)s) == -1) + { + %(fail)s; + } + + } else { + if (%(_zout)s != %(_z)s) + { + Py_XDECREF(%(_zout)s); + %(_zout)s = %(_z)s; + Py_INCREF(%(_zout)s); + } + } + + Nz = PyArray_DIMS(%(_zout)s); + Sz = PyArray_STRIDES(%(_zout)s); + """ + + setup_z_Nz_Sz_outplace = """ + npy_intp dims[2]; + dims[0] = (PyArray_DIMS(%(_z)s)[0] >= Nx[0]) ? PyArray_DIMS(%(_z)s)[0] : Nx[0]; + dims[1] = (PyArray_DIMS(%(_z)s)[1] >= Ny[1]) ? PyArray_DIMS(%(_z)s)[1] : Ny[1]; + + // Check if we need to allocate new array + if ((NULL == %(_zout)s) + || (PyArray_DIMS(%(_zout)s)[0] != dims[0]) + || (PyArray_DIMS(%(_zout)s)[1] != dims[1])) + { + Py_XDECREF(%(_zout)s); + %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_z)s)); + // fprintf(stderr, "Gemm Allocating z output array with shape (%%i %%i)\\n", dims[0], dims[1]); + if(!%(_zout)s) { + PyErr_SetString(PyExc_MemoryError, + "failed to alloc gemm_no_inplace output"); + %(fail)s + } + } + + // fprintf(stderr, "Gemm Broadcasting Z into shape (%%i %%i)\\n", dims[0], dims[1]); + if(PyArray_CopyInto(%(_zout)s, %(_z)s) == -1) + { + %(fail)s + } + + Nz = PyArray_DIMS(%(_zout)s); + Sz = PyArray_STRIDES(%(_zout)s); + """ + + broadcast_xy = """ + // Broadcast X if needed + if (Nz[0] > Nx[0]) + { + npy_intp dims[2]; + dims[0] = Nz[0]; + dims[1] = Nx[1]; + // fprintf(stderr, "Gemm Broadcasting X into shape (%%i %%i)\\n", dims[0], dims[1]); + PyArrayObject *x_new = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_x)s)); + if(!x_new) { + PyErr_SetString(PyExc_MemoryError, + "failed to alloc gemm_inplace input"); + %(fail)s + } + + if(PyArray_CopyInto(x_new, %(_x)s) == -1) + { + %(fail)s + } + + Py_DECREF(%(_x)s); + %(_x)s = x_new; + + Nx = PyArray_DIMS(%(_x)s); + Sx = PyArray_STRIDES(%(_x)s); + } + + // Broadcast Y if needed + if (Nz[1] > Ny[1]) + { + npy_intp dims[2]; + dims[0] = Ny[0]; + dims[1] = Nz[1]; + // fprintf(stderr, "Gemm Broadcasting Y into shape (%%i %%i)\\n", dims[0], dims[1]); + PyArrayObject *y_new = (PyArrayObject*)PyArray_SimpleNew(2, dims, PyArray_TYPE(%(_x)s)); + if(!y_new) { + PyErr_SetString(PyExc_MemoryError, + "failed to alloc gemm_inplace input"); + %(fail)s + } + + if(PyArray_CopyInto(y_new, %(_y)s) == -1) + { + %(fail)s + } + + Py_DECREF(%(_y)s); + %(_y)s = y_new; + + Ny = PyArray_DIMS(%(_y)s); + Sy = PyArray_STRIDES(%(_y)s); + } + + """ + + case_float_ab_constants = """ + #define REAL float + float a = (PyArray_DESCR(%(_a)s)->type_num == NPY_FLOAT) + ? (REAL)(((float*)PyArray_DATA(%(_a)s))[0]) + : (REAL)(((double*)PyArray_DATA(%(_a)s))[0]); + float b = (PyArray_DESCR(%(_b)s)->type_num == NPY_FLOAT) ? + (REAL)(((float*)PyArray_DATA(%(_b)s))[0]) + : (REAL)(((double*)PyArray_DATA(%(_b)s))[0]); + #undef REAL + """ + case_double_ab_constants = """ + #define REAL double + double a = (PyArray_DESCR(%(_a)s)->type_num == NPY_FLOAT) + ? (REAL)(((float*)PyArray_DATA(%(_a)s))[0]) + : (REAL)(((double*)PyArray_DATA(%(_a)s))[0]); + double b = (PyArray_DESCR(%(_b)s)->type_num == NPY_FLOAT) ? + (REAL)(((float*)PyArray_DATA(%(_b)s))[0]) + : (REAL)(((double*)PyArray_DATA(%(_b)s))[0]); + #undef REAL + """ + + def c_code(self, node, name, inp, out, sub): + _z, _a, _x, _y, _b = inp + (_zout,) = out + if node.inputs[0].type.dtype.startswith("complex"): + raise MethodNotDefined(f"{self.__class__.__name__}.c_code") + full_code = self.build_gemm_call() % dict(locals(), **sub) + return full_code + + def c_code_cache_version(self): + gv = self.build_gemm_version() + if gv: + return (8, *gv) + else: + return gv + + +gemm_inplace = Gemm(inplace=True) +gemm_no_inplace = Gemm(inplace=False) +# For the user interface. PyTensor optimization will make them inplace +gemm = gemm_no_inplace +pprint.assign(gemm_inplace, FunctionPrinter(["gemm_inplace"])) +pprint.assign(gemm_no_inplace, FunctionPrinter(["gemm_no_inplace"])) + + +class Dot22(GemmRelated): + """Compute a matrix-matrix product. + + This is a specialization of the more general Dot(). + + """ + + check_input = False + + def make_node(self, x, y): + x = as_tensor_variable(x) + y = as_tensor_variable(y) + + if any(not isinstance(i.type, DenseTensorType) for i in (x, y)): + raise NotImplementedError("Only dense tensor types are supported") + + dtypes = ("float16", "float32", "float64", "complex64", "complex128") + if x.type.ndim != 2 or x.type.dtype not in dtypes: + raise TypeError(x) + if y.type.ndim != 2 or y.type.dtype not in dtypes: + raise TypeError(y) + if y.type.dtype != x.type.dtype: + raise TypeError("dtype mismatch to Dot22") + outputs = [tensor(dtype=x.type.dtype, shape=(x.type.shape[0], y.type.shape[1]))] + return Apply(self, [x, y], outputs) + + def perform(self, node, inputs, output_storage): + output_storage[0][0] = np.dot(*inputs) + + def infer_shape(self, fgraph, node, input_shapes): + return [[input_shapes[0][0], input_shapes[1][1]]] + + setup_z_Nz_Sz = """ + if ((NULL == %(_zout)s) + || (PyArray_DIMS(%(_zout)s)[0] != PyArray_DIMS(%(_x)s)[0]) + || (PyArray_DIMS(%(_zout)s)[1] != PyArray_DIMS(%(_y)s)[1])) + { + if (NULL != %(_zout)s) Py_XDECREF(%(_zout)s); + npy_intp dims[2]; + dims[0] = PyArray_DIMS(%(_x)s)[0]; + dims[1] = PyArray_DIMS(%(_y)s)[1]; + %(_zout)s = (PyArrayObject*)PyArray_SimpleNew(2, dims, + PyArray_TYPE(%(_x)s)); + //fprintf(stderr, "Dot Allocating %%i %%i\\n", dims[0], dims[1]); + if(!%(_zout)s) { + PyErr_SetString(PyExc_MemoryError, + "failed to alloc dot22 output"); + %(fail)s + } + } + Nz = PyArray_DIMS(%(_zout)s); + Sz = PyArray_STRIDES(%(_zout)s); + + """ + broadcast_xy = "" + check_ab_double_or_float = "" + case_float_ab_constants = """ + float a = 1.0; + float b = 0.0; + """ + case_double_ab_constants = """ + double a = 1.0; + double b = 0.0; + """ + + def c_code(self, node, name, inp, out, sub): # DEBUG + _x, _y = inp + (_zout,) = out + if node.inputs[0].type.dtype.startswith("complex"): + raise MethodNotDefined(f"{self.__class__.__name__}.c_code") + if len(self.c_libraries()) <= 0: + raise NotImplementedError() + full_code = self.build_gemm_call() % dict(locals(), **sub) + return full_code + + def c_code_cache_version(self): + gv = self.build_gemm_version() + if gv: + return (2, *gv) + else: + return gv + + +_dot22 = Dot22() + + +class Dot22Scalar(GemmRelated): + """Compute a matrix-matrix product. + + This is a specialization of the more general Dot() + Used to call optimized gemm implementation. + Also used to generate a gemm later. + compute scalar*dot(x,y). + + """ + + check_input = False + + def make_node(self, x, y, a): + if any(not isinstance(i.type, DenseTensorType) for i in (x, y, a)): + raise NotImplementedError("Only dense tensor types are supported") + + if a.ndim != 0: + raise TypeError(Gemm.E_scalar, a) + if x.ndim != 2: + raise TypeError(Gemm.E_rank, x) + if y.ndim != 2: + raise TypeError(Gemm.E_rank, y) + + if not (a.dtype == x.dtype == y.dtype): + raise TypeError( + "Dot22Scalar requires matching dtypes", (a.dtype, x.dtype, y.dtype) + ) + + if not a.dtype.startswith("float") and not a.dtype.startswith("complex"): + raise TypeError("Dot22Scalar requires float or complex args", a.dtype) + + sz = (x.type.shape[0], y.type.shape[1]) + outputs = [tensor(dtype=x.type.dtype, shape=sz)] + return Apply(self, [x, y, a], outputs) + + def perform(self, node, inp, out): + x, y, scalar = inp + (z,) = out + try: + z[0] = np.asarray(scalar * np.dot(x, y)) + except ValueError as e: + # The error raised by numpy has no shape information, we + # mean to add that + e.args = (*e.args, x.shape, y.shape) + raise + + def infer_shape(self, fgraph, node, input_shapes): + return [[input_shapes[0][0], input_shapes[1][1]]] + + setup_z_Nz_Sz = Dot22.setup_z_Nz_Sz + broadcast_xy = "" + + check_ab_double_or_float = """ + if ((PyArray_DESCR(%(_a)s)->type_num != NPY_DOUBLE) + && (PyArray_DESCR(%(_a)s)->type_num != NPY_FLOAT)) + {PyErr_SetString(PyExc_NotImplementedError, + "type(a) is not double or float"); %(fail)s;} + + """ + case_float_ab_constants = """ + #define REAL float + float a = (PyArray_DESCR(%(_a)s)->type_num == NPY_FLOAT) + ? (REAL)(((float*)PyArray_DATA(%(_a)s))[0]) + : (REAL)(((double*)PyArray_DATA(%(_a)s))[0]); + #undef REAL + float b = 0.0; + """ + + case_double_ab_constants = """ + #define REAL double + double a = (PyArray_DESCR(%(_a)s)->type_num == NPY_FLOAT) + ? (REAL)(((float*)PyArray_DATA(%(_a)s))[0]) + : (REAL)(((double*)PyArray_DATA(%(_a)s))[0]); + #undef REAL + double b = 0.0; + """ + + def c_code(self, node, name, inp, out, sub): + _x, _y, _a = inp + (_zout,) = out + if node.inputs[0].type.dtype.startswith("complex"): + raise MethodNotDefined(f"{self.__class__.__name__}.c_code") + if len(self.c_libraries()) <= 0: + raise NotImplementedError() + full_code = self.build_gemm_call() % dict(locals(), **sub) + return full_code + + def c_code_cache_version(self): + gv = self.build_gemm_version() + if gv: + return (2, *gv) + else: + return gv + + +_dot22scalar = Dot22Scalar() diff --git a/pytensor/tensor/blas/gemv.py b/pytensor/tensor/blas/gemv.py new file mode 100644 index 0000000000..d39fa89291 --- /dev/null +++ b/pytensor/tensor/blas/gemv.py @@ -0,0 +1,117 @@ +"""BLAS GEMV operation: matrix-vector multiply with accumulation. + +Computes: beta * y + alpha * dot(A, x) +""" + +import numpy as np +from scipy.linalg import get_blas_funcs + +from pytensor.graph.basic import Apply +from pytensor.graph.op import Op +from pytensor.tensor.basic import as_tensor_variable +from pytensor.tensor.blas._core import must_initialize_y_gemv +from pytensor.tensor.type import DenseTensorType + + +class Gemv(Op): + """ + expression is beta * y + alpha * A x + + A is matrix + x, y are vectors + alpha, beta are scalars + output is a vector that can be inplace on y + + """ + + __props__ = ("inplace",) + + def __init__(self, inplace): + self.inplace = inplace + if inplace: + self.destroy_map = {0: [0]} + + def __str__(self): + if self.inplace: + return f"{self.__class__.__name__}{{inplace}}" + else: + return f"{self.__class__.__name__}{{no_inplace}}" + + def make_node(self, y, alpha, A, x, beta): + y = as_tensor_variable(y) + x = as_tensor_variable(x) + A = as_tensor_variable(A) + alpha = as_tensor_variable(alpha) + beta = as_tensor_variable(beta) + if y.dtype != A.dtype or y.dtype != x.dtype: + raise TypeError( + "Gemv requires matching dtypes", (y.dtype, A.dtype, x.dtype) + ) + if A.ndim != 2: + raise TypeError("gemv requires matrix for A", A.type) + if x.ndim != 1: + raise TypeError("gemv requires vector for x", x.type) + if y.ndim != 1: + raise TypeError("gemv requires vector for y", y.type) + + inputs = [y, alpha, A, x, beta] + + if any(not isinstance(i.type, DenseTensorType) for i in inputs): + raise NotImplementedError("Only dense tensor types are supported") + + return Apply(self, inputs, [y.type()]) + + def perform(self, node, inputs, out_storage): + y, alpha, A, x, beta = inputs + if ( + y.shape[0] != 0 + and x.shape[0] != 0 + and y.dtype in {"float32", "float64", "complex64", "complex128"} + ): + gemv = get_blas_funcs("gemv", dtype=y.dtype) + + if A.shape[0] != y.shape[0] or A.shape[1] != x.shape[0]: + raise ValueError( + "Incompatible shapes for gemv " + f"(beta * y + alpha * dot(A, x)). y: {y.shape}, A: {A.shape}, x: {x.shape}" + ) + + if beta == 0 and must_initialize_y_gemv(): + # Most BLAS implementations of GEMV ignore y=nan when beta=0 + # PyTensor considers that the correct behavior, + # and even exploits it to avoid copying or initializing outputs. + # By deciding to exploit this, however, it becomes our responsibility + # to ensure the behavior even in the rare cases BLAS deviates, + # or users will get errors, even for graphs that had no nan to begin with. + y.fill(0) + + # Here I suppose that A is in c order. If we don't make it + # explicitly as fortran order, scipy 0.7.2 seam to create + # a copy in fortran order instead of just reshaping it + # and using the trans flag. + # If A is already in fortran order, make it in c order and using the + # trans flag don't seam to cause slowdown. + # out_storage[0][0] = gemv(alpha, A, x, beta, y, + # overwrite_y=self.inplace) + out_storage[0][0] = gemv( + alpha, A.T, x, beta, y, overwrite_y=self.inplace, trans=True + ) + else: + out = np.dot(A, x) + if alpha != 1: + out *= alpha + if beta != 0: + if beta != 1: + out += beta * y + else: + out += y + out_storage[0][0] = np.asarray(out, dtype=y.dtype) + + def infer_shape(self, fgraph, node, input_shapes): + return [input_shapes[0]] + + +gemv_no_inplace = Gemv(inplace=False) +gemv_inplace = Gemv(inplace=True) +# For the user interface. Opt will make them inplace later +gemv = gemv_no_inplace diff --git a/pytensor/tensor/blas/ger.py b/pytensor/tensor/blas/ger.py new file mode 100644 index 0000000000..bc333e9cc1 --- /dev/null +++ b/pytensor/tensor/blas/ger.py @@ -0,0 +1,77 @@ +from scipy.linalg import get_blas_funcs + +from pytensor.graph.basic import Apply +from pytensor.graph.op import Op +from pytensor.tensor.basic import as_tensor_variable +from pytensor.tensor.type import DenseTensorType + + +class Ger(Op): + """ + BLAS defines general rank-1 update GER as A <- A + alpha x y' + + for matrix A, scalar alpha, vectors x and y. + + This interface to GER allows non-destructive operation on A via the + `destructive` argument to the constructor. + + """ + + __props__ = ("destructive",) + + def __init__(self, destructive): + self.destructive = destructive + if destructive: + self.destroy_map = {0: [0]} + + def __str__(self): + if self.destructive: + return f"{self.__class__.__name__}{{destructive}}" + else: + return f"{self.__class__.__name__}{{non-destructive}}" + + def make_node(self, A, alpha, x, y): + A = as_tensor_variable(A) + y = as_tensor_variable(y) + x = as_tensor_variable(x) + alpha = as_tensor_variable(alpha) + if not (A.dtype == x.dtype == y.dtype == alpha.dtype): + raise TypeError( + "ger requires matching dtypes", (A.dtype, alpha.dtype, x.dtype, y.dtype) + ) + if alpha.ndim != 0: + raise TypeError("ger requires scalar alpha", alpha.type) + if A.ndim != 2: + raise TypeError("ger requires matrix for A", A.type) + if x.ndim != 1: + raise TypeError("ger requires vector for x", x.type) + if y.ndim != 1: + raise TypeError("ger requires vector for y", y.type) + + if x.dtype not in ("float32", "float64", "complex64", "complex128"): + raise TypeError("only float and complex types supported", x.dtype) + + inputs = [A, alpha, x, y] + if any(not isinstance(i.type, DenseTensorType) for i in inputs): + raise NotImplementedError("Only dense tensor types are supported") + + return Apply(self, inputs, [A.type()]) + + def perform(self, node, inputs, output_storage): + A, alpha, x, y = inputs + if A.size: + # GER doesn't handle zero-sized inputs + ger_func = get_blas_funcs("ger", dtype=A.dtype) + if A.flags["C_CONTIGUOUS"]: + # Work on transposed system to avoid copying + A = ger_func(alpha, y, x, a=A.T, overwrite_a=self.destructive).T + else: + A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive) + output_storage[0][0] = A + + def infer_shape(self, fgraph, node, input_shapes): + return [input_shapes[0]] + + +ger = Ger(destructive=False) +ger_destructive = Ger(destructive=True) From 984f0775f19348d4b5a8a193e7e2b57049c7e3fd Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 12 Apr 2026 18:06:02 -0500 Subject: [PATCH 02/13] Move blas c-code templates and codegen utils to tensor/blas/c_code and tensor/blas. --- pytensor/tensor/__init__.py | 1 - pytensor/tensor/basic.py | 4 +- pytensor/tensor/blas/__init__.py | 79 ++++++++++++------- pytensor/tensor/blas/batched.py | 4 +- pytensor/tensor/{ => blas}/blas_c.py | 11 +-- pytensor/tensor/{ => blas}/blas_headers.py | 0 .../{ => blas}/c_code/alt_blas_common.h | 0 .../{ => blas}/c_code/alt_blas_template.c | 0 pytensor/tensor/blas/gemm.py | 14 +--- pytensor/tensor/rewriting/blas_c.py | 2 +- tests/benchmarks/test_blas.py | 2 +- tests/tensor/rewriting/test_math.py | 2 +- tests/tensor/rewriting/test_subtensor_lift.py | 2 +- tests/tensor/test_blas_c.py | 2 +- tests/tensor/test_math.py | 8 +- 15 files changed, 73 insertions(+), 58 deletions(-) rename pytensor/tensor/{ => blas}/blas_c.py (99%) rename pytensor/tensor/{ => blas}/blas_headers.py (100%) rename pytensor/tensor/{ => blas}/c_code/alt_blas_common.h (100%) rename pytensor/tensor/{ => blas}/c_code/alt_blas_template.c (100%) diff --git a/pytensor/tensor/__init__.py b/pytensor/tensor/__init__.py index a814788dae..e731cfc6d2 100644 --- a/pytensor/tensor/__init__.py +++ b/pytensor/tensor/__init__.py @@ -105,7 +105,6 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int: # adds shared-variable constructors from pytensor.tensor import ( blas, - blas_c, sharedvar, xlogx, ) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index d1d19f41df..ed9e12a2de 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1774,9 +1774,9 @@ def do_constant_folding(self, fgraph, node): | pytensor.tensor.subtensor.AdvancedIncSubtensor1 | pytensor.tensor.subtensor.AdvancedIncSubtensor | pytensor.tensor.blas.Gemv - | pytensor.tensor.blas_c.CGemv + | pytensor.tensor.blas.CGemv | pytensor.tensor.blas.Ger - | pytensor.tensor.blas_c.CGer, + | pytensor.tensor.blas.CGer, ) ): # Ops that will work inplace on the Alloc. So if they diff --git a/pytensor/tensor/blas/__init__.py b/pytensor/tensor/blas/__init__.py index d5e77f51b1..49f205fdf7 100644 --- a/pytensor/tensor/blas/__init__.py +++ b/pytensor/tensor/blas/__init__.py @@ -82,8 +82,6 @@ from pytensor.tensor.blas._core import ( _ldflags, _logger, - blas_header_text, - blas_header_version, ldflags, must_initialize_y_gemv, view_roots, @@ -94,6 +92,26 @@ batched_dot, batched_tensordot, ) +from pytensor.tensor.blas.blas_c import ( + BaseBLAS, + CGemv, + CGer, + cgemv_inplace, + cgemv_no_inplace, + cger_inplace, + cger_no_inplace, +) +from pytensor.tensor.blas.blas_c import ( + must_initialize_y_gemv as must_initialize_y_gemv_c, +) +from pytensor.tensor.blas.blas_headers import ( + blas_header_text, + blas_header_version, + cblas_header_text, + detect_macos_sdot_bug, + mkl_threads_text, + openblas_threads_text, +) from pytensor.tensor.blas.gemm import ( Dot22, Dot22Scalar, @@ -110,36 +128,43 @@ __all__ = [ - # Core utilities - "view_roots", - "must_initialize_y_gemv", - "ldflags", + "BaseBLAS", + "BatchedDot", + "CGemv", + "CGer", + "Dot22", + "Dot22Scalar", + "Gemm", + "GemmRelated", + "Gemv", + "Ger", + "_batched_dot", + "_dot22", + "_dot22scalar", "_ldflags", "_logger", + "batched_dot", + "batched_tensordot", "blas_header_text", "blas_header_version", - # Gemv - "Gemv", - "gemv_no_inplace", - "gemv_inplace", + "cblas_header_text", + "cgemv_inplace", + "cgemv_no_inplace", + "cger_inplace", + "cger_no_inplace", + "detect_macos_sdot_bug", + "gemm", + "gemm_inplace", + "gemm_no_inplace", "gemv", - # Ger - "Ger", + "gemv_inplace", + "gemv_no_inplace", "ger", "ger_destructive", - # GemmRelated / Gemm / Dot22 / Dot22Scalar - "GemmRelated", - "Gemm", - "gemm_inplace", - "gemm_no_inplace", - "gemm", - "Dot22", - "_dot22", - "Dot22Scalar", - "_dot22scalar", - # BatchedDot - "BatchedDot", - "_batched_dot", - "batched_dot", - "batched_tensordot", + "ldflags", + "mkl_threads_text", + "must_initialize_y_gemv", + "must_initialize_y_gemv_c", + "openblas_threads_text", + "view_roots", ] diff --git a/pytensor/tensor/blas/batched.py b/pytensor/tensor/blas/batched.py index 59a4036555..3f232dd59f 100644 --- a/pytensor/tensor/blas/batched.py +++ b/pytensor/tensor/blas/batched.py @@ -14,8 +14,8 @@ from pytensor.graph.basic import Apply from pytensor.link.c.op import COp from pytensor.tensor.basic import as_tensor_variable, cast -from pytensor.tensor.blas._core import blas_header_text, ldflags -from pytensor.tensor.blas_headers import blas_header_version +from pytensor.tensor.blas._core import ldflags +from pytensor.tensor.blas.blas_headers import blas_header_text, blas_header_version from pytensor.tensor.math import dot, tensordot from pytensor.tensor.shape import specify_broadcastable from pytensor.tensor.type import DenseTensorType, tensor diff --git a/pytensor/tensor/blas_c.py b/pytensor/tensor/blas/blas_c.py similarity index 99% rename from pytensor/tensor/blas_c.py rename to pytensor/tensor/blas/blas_c.py index 83cd87796a..848a730cd6 100644 --- a/pytensor/tensor/blas_c.py +++ b/pytensor/tensor/blas/blas_c.py @@ -1,13 +1,10 @@ from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType from pytensor.scalar import bool as bool_t -from pytensor.tensor.blas import ( - Gemv, - Ger, - blas_header_text, - blas_header_version, - ldflags, -) +from pytensor.tensor.blas._core import ldflags +from pytensor.tensor.blas.blas_headers import blas_header_text, blas_header_version +from pytensor.tensor.blas.gemv import Gemv +from pytensor.tensor.blas.ger import Ger class BaseBLAS(COp): diff --git a/pytensor/tensor/blas_headers.py b/pytensor/tensor/blas/blas_headers.py similarity index 100% rename from pytensor/tensor/blas_headers.py rename to pytensor/tensor/blas/blas_headers.py diff --git a/pytensor/tensor/c_code/alt_blas_common.h b/pytensor/tensor/blas/c_code/alt_blas_common.h similarity index 100% rename from pytensor/tensor/c_code/alt_blas_common.h rename to pytensor/tensor/blas/c_code/alt_blas_common.h diff --git a/pytensor/tensor/c_code/alt_blas_template.c b/pytensor/tensor/blas/c_code/alt_blas_template.c similarity index 100% rename from pytensor/tensor/c_code/alt_blas_template.c rename to pytensor/tensor/blas/c_code/alt_blas_template.c diff --git a/pytensor/tensor/blas/gemm.py b/pytensor/tensor/blas/gemm.py index 3f2aec1fd1..b15a8fd37e 100644 --- a/pytensor/tensor/blas/gemm.py +++ b/pytensor/tensor/blas/gemm.py @@ -1,10 +1,3 @@ -"""BLAS GEMM family: Gemm, Dot22, Dot22Scalar, and the GemmRelated base class. - -Gemm computes: b*z + a*dot(x,y) -Dot22 computes: dot(x, y) (matrix-matrix, BLAS-accelerated) -Dot22Scalar computes: scalar * dot(x, y) -""" - import numpy as np import pytensor.scalar @@ -16,11 +9,13 @@ from pytensor.scalar import bool as bool_t from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.blas._core import ( - blas_header_text, - blas_header_version, ldflags, view_roots, ) +from pytensor.tensor.blas.blas_headers import ( + blas_header_text, + blas_header_version, +) from pytensor.tensor.type import DenseTensorType, tensor @@ -28,7 +23,6 @@ class GemmRelated(COp): """Base class for Gemm and Dot22. This class provides a kind of templated gemm Op. - """ __props__: tuple[str, ...] = () diff --git a/pytensor/tensor/rewriting/blas_c.py b/pytensor/tensor/rewriting/blas_c.py index d4220d6d7a..37d2d1dce7 100644 --- a/pytensor/tensor/rewriting/blas_c.py +++ b/pytensor/tensor/rewriting/blas_c.py @@ -2,7 +2,7 @@ from pytensor.graph.rewriting.basic import dfs_rewriter from pytensor.tensor import basic as ptb from pytensor.tensor.blas import gemv_inplace, gemv_no_inplace, ger, ger_destructive -from pytensor.tensor.blas_c import ( +from pytensor.tensor.blas.blas_c import ( CGemv, CGer, cgemv_inplace, diff --git a/tests/benchmarks/test_blas.py b/tests/benchmarks/test_blas.py index 5cd53d309f..17ae5573ce 100644 --- a/tests/benchmarks/test_blas.py +++ b/tests/benchmarks/test_blas.py @@ -3,7 +3,7 @@ from pytensor import In, function from pytensor.tensor import dot, empty, matrix, outer, scalar, tensor, vector -from pytensor.tensor.blas_c import CGemv +from pytensor.tensor.blas.blas_c import CGemv @pytest.mark.parametrize("dtype", ("float64", "float32", "mixed")) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 6d0b0d978c..1d87a1c661 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -33,7 +33,7 @@ from pytensor.scalar import PolyGamma, Psi, TriGamma from pytensor.tensor.basic import Alloc, constant, join, second, switch from pytensor.tensor.blas import Dot22, Gemv -from pytensor.tensor.blas_c import CGemv +from pytensor.tensor.blas.blas_c import CGemv from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.linalg.constructors import BlockDiagonal diff --git a/tests/tensor/rewriting/test_subtensor_lift.py b/tests/tensor/rewriting/test_subtensor_lift.py index 6dfa5b82f0..1f9d857adf 100644 --- a/tests/tensor/rewriting/test_subtensor_lift.py +++ b/tests/tensor/rewriting/test_subtensor_lift.py @@ -37,7 +37,7 @@ ) from pytensor.tensor.basic import MakeVector, concatenate, expand_dims, make_vector from pytensor.tensor.blas import Dot22, Gemv -from pytensor.tensor.blas_c import CGemv +from pytensor.tensor.blas.blas_c import CGemv from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.math import Dot diff --git a/tests/tensor/test_blas_c.py b/tests/tensor/test_blas_c.py index 9184b47020..5b3da5cff9 100644 --- a/tests/tensor/test_blas_c.py +++ b/tests/tensor/test_blas_c.py @@ -8,7 +8,7 @@ from pytensor.compile import get_mode from pytensor.tensor.basic import AllocEmpty from pytensor.tensor.blas import Ger -from pytensor.tensor.blas_c import CGemv, CGer, must_initialize_y_gemv +from pytensor.tensor.blas.blas_c import CGemv, CGer, must_initialize_y_gemv from pytensor.tensor.type import ( dmatrix, dscalar, diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 2516cb6ff4..baf172818b 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -27,7 +27,7 @@ from pytensor.link.numba import NumbaLinker from pytensor.printing import pprint from pytensor.raise_op import Assert -from pytensor.tensor import blas, blas_c +from pytensor.tensor import blas from pytensor.tensor.basic import ( as_tensor_variable, constant, @@ -2810,7 +2810,7 @@ def test_Dot(self): [advec, bdvec], [dot(advec, bdvec)], [advec_val, bdvec_val], - (Dot, blas.Dot22, blas.Gemv, blas_c.CGemv), + (Dot, blas.Dot22, blas.Gemv, blas.CGemv), ) # mat/mat @@ -2831,7 +2831,7 @@ def test_Dot(self): [advec, bdmat], [dot(advec, bdmat)], [advec_val, bdmat_val], - (Dot, blas.Dot22, blas.Gemv, blas_c.CGemv), + (Dot, blas.Dot22, blas.Gemv, blas.CGemv), ) # mat/vec @@ -2840,7 +2840,7 @@ def test_Dot(self): [admat, bdvec], [dot(admat, bdvec)], [admat_val, bdvec_val], - (Dot, blas.Dot22, blas.Gemv, blas_c.CGemv), + (Dot, blas.Dot22, blas.Gemv, blas.CGemv), ) From cd41fe6b94be7e015e6a34517868398eb175cf9f Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 12 Apr 2026 18:25:46 -0500 Subject: [PATCH 03/13] Move static C-code strings to header files for readability --- pytensor/tensor/blas/__init__.py | 4 - pytensor/tensor/blas/c_code/fortran_blas.h | 214 ++++++++++++++++++ pytensor/tensor/blas/c_code/mkl_threads.h | 33 +++ .../tensor/blas/c_code/openblas_threads.h | 16 ++ pytensor/tensor/blas/gemm.py | 1 - 5 files changed, 263 insertions(+), 5 deletions(-) create mode 100644 pytensor/tensor/blas/c_code/fortran_blas.h create mode 100644 pytensor/tensor/blas/c_code/mkl_threads.h create mode 100644 pytensor/tensor/blas/c_code/openblas_threads.h diff --git a/pytensor/tensor/blas/__init__.py b/pytensor/tensor/blas/__init__.py index 49f205fdf7..750218c7e8 100644 --- a/pytensor/tensor/blas/__init__.py +++ b/pytensor/tensor/blas/__init__.py @@ -107,8 +107,6 @@ from pytensor.tensor.blas.blas_headers import ( blas_header_text, blas_header_version, - cblas_header_text, - detect_macos_sdot_bug, mkl_threads_text, openblas_threads_text, ) @@ -147,12 +145,10 @@ "batched_tensordot", "blas_header_text", "blas_header_version", - "cblas_header_text", "cgemv_inplace", "cgemv_no_inplace", "cger_inplace", "cger_no_inplace", - "detect_macos_sdot_bug", "gemm", "gemm_inplace", "gemm_no_inplace", diff --git a/pytensor/tensor/blas/c_code/fortran_blas.h b/pytensor/tensor/blas/c_code/fortran_blas.h new file mode 100644 index 0000000000..b1b597f33d --- /dev/null +++ b/pytensor/tensor/blas/c_code/fortran_blas.h @@ -0,0 +1,214 @@ +/* + * Fortran BLAS interface declarations for PyTensor. + * + * These are the extern "C" declarations for the Fortran BLAS routines + * (with trailing underscore convention). Used by GemmRelated, CGemv, CGer, etc. + */ + +#ifndef PYTENSOR_FORTRAN_BLAS_H +#define PYTENSOR_FORTRAN_BLAS_H + +extern "C" +{ + + void xerbla_(char*, void *); + +/***********/ +/* Level 1 */ +/***********/ + +/* Single Precision */ + + void srot_(const int*, float *, const int*, float *, const int*, const float *, const float *); + void srotg_(float *,float *,float *,float *); + void srotm_( const int*, float *, const int*, float *, const int*, const float *); + void srotmg_(float *,float *,float *,const float *, float *); + void sswap_( const int*, float *, const int*, float *, const int*); + void scopy_( const int*, const float *, const int*, float *, const int*); + void saxpy_( const int*, const float *, const float *, const int*, float *, const int*); + float sdot_(const int*, const float *, const int*, const float *, const int*); + void sdot_sub_(const int*, const float *, const int*, const float *, const int*, float *); + void sdsdot_sub_( const int*, const float *, const float *, const int*, const float *, const int*, float *); + void sscal_( const int*, const float *, float *, const int*); + void snrm2_sub_( const int*, const float *, const int*, float *); + void sasum_sub_( const int*, const float *, const int*, float *); + void isamax_sub_( const int*, const float * , const int*, const int*); + +/* Double Precision */ + + void drot_(const int*, double *, const int*, double *, const int*, const double *, const double *); + void drotg_(double *,double *,double *,double *); + void drotm_( const int*, double *, const int*, double *, const int*, const double *); + void drotmg_(double *,double *,double *,const double *, double *); + void dswap_( const int*, double *, const int*, double *, const int*); + void dcopy_( const int*, const double *, const int*, double *, const int*); + void daxpy_( const int*, const double *, const double *, const int*, double *, const int*); + void dswap_( const int*, double *, const int*, double *, const int*); + double ddot_(const int*, const double *, const int*, const double *, const int*); + void dsdot_sub_(const int*, const float *, const int*, const float *, const int*, double *); + void ddot_sub_( const int*, const double *, const int*, const double *, const int*, double *); + void dscal_( const int*, const double *, double *, const int*); + void dnrm2_sub_( const int*, const double *, const int*, double *); + void dasum_sub_( const int*, const double *, const int*, double *); + void idamax_sub_( const int*, const double * , const int*, const int*); + +/* Single Complex Precision */ + + void cswap_( const int*, void *, const int*, void *, const int*); + void ccopy_( const int*, const void *, const int*, void *, const int*); + void caxpy_( const int*, const void *, const void *, const int*, void *, const int*); + void cswap_( const int*, void *, const int*, void *, const int*); + void cdotc_sub_( const int*, const void *, const int*, const void *, const int*, void *); + void cdotu_sub_( const int*, const void *, const int*, const void *, const int*, void *); + void cscal_( const int*, const void *, void *, const int*); + void icamax_sub_( const int*, const void *, const int*, const int*); + void csscal_( const int*, const float *, void *, const int*); + void scnrm2_sub_( const int*, const void *, const int*, float *); + void scasum_sub_( const int*, const void *, const int*, float *); + +/* Double Complex Precision */ + + void zswap_( const int*, void *, const int*, void *, const int*); + void zcopy_( const int*, const void *, const int*, void *, const int*); + void zaxpy_( const int*, const void *, const void *, const int*, void *, const int*); + void zswap_( const int*, void *, const int*, void *, const int*); + void zdotc_sub_( const int*, const void *, const int*, const void *, const int*, void *); + void zdotu_sub_( const int*, const void *, const int*, const void *, const int*, void *); + void zdscal_( const int*, const double *, void *, const int*); + void zscal_( const int*, const void *, void *, const int*); + void dznrm2_sub_( const int*, const void *, const int*, double *); + void dzasum_sub_( const int*, const void *, const int*, double *); + void izamax_sub_( const int*, const void *, const int*, const int*); + +/***********/ +/* Level 2 */ +/***********/ + +/* Single Precision */ + + void sgemv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); + void sgbmv_(char*, const int*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); + void ssymv_(char*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); + void ssbmv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); + void sspmv_(char*, const int*, const float *, const float *, const float *, const int*, const float *, float *, const int*); + void strmv_( char*, char*, char*, const int*, const float *, const int*, float *, const int*); + void stbmv_( char*, char*, char*, const int*, const int*, const float *, const int*, float *, const int*); + void strsv_( char*, char*, char*, const int*, const float *, const int*, float *, const int*); + void stbsv_( char*, char*, char*, const int*, const int*, const float *, const int*, float *, const int*); + void stpmv_( char*, char*, char*, const int*, const float *, float *, const int*); + void stpsv_( char*, char*, char*, const int*, const float *, float *, const int*); + void sger_( const int*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*); + void ssyr_(char*, const int*, const float *, const float *, const int*, float *, const int*); + void sspr_(char*, const int*, const float *, const float *, const int*, float *); + void sspr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *); + void ssyr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*); + +/* Double Precision */ + + void dgemv_(char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); + void dgbmv_(char*, const int*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); + void dsymv_(char*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); + void dsbmv_(char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); + void dspmv_(char*, const int*, const double *, const double *, const double *, const int*, const double *, double *, const int*); + void dtrmv_( char*, char*, char*, const int*, const double *, const int*, double *, const int*); + void dtbmv_( char*, char*, char*, const int*, const int*, const double *, const int*, double *, const int*); + void dtrsv_( char*, char*, char*, const int*, const double *, const int*, double *, const int*); + void dtbsv_( char*, char*, char*, const int*, const int*, const double *, const int*, double *, const int*); + void dtpmv_( char*, char*, char*, const int*, const double *, double *, const int*); + void dtpsv_( char*, char*, char*, const int*, const double *, double *, const int*); + void dger_( const int*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*); + void dsyr_(char*, const int*, const double *, const double *, const int*, double *, const int*); + void dspr_(char*, const int*, const double *, const double *, const int*, double *); + void dspr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *); + void dsyr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*); + +/* Single Complex Precision */ + + void cgemv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*); + void cgbmv_(char*, const int*, const int*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*); + void chemv_(char*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*); + void chbmv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*); + void chpmv_(char*, const int*, const void *, const void *, const void *, const int*, const void *, void *, const int*); + void ctrmv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*); + void ctbmv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*); + void ctpmv_( char*, char*, char*, const int*, const void *, void *, const int*); + void ctrsv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*); + void ctbsv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*); + void ctpsv_( char*, char*, char*, const int*, const void *, void *,const int*); + void cgerc_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*); + void cgeru_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*); + void cher_(char*, const int*, const float *, const void *, const int*, void *, const int*); + void cher2_(char*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*); + void chpr_(char*, const int*, const float *, const void *, const int*, void *); + void chpr2_(char*, const int*, const float *, const void *, const int*, const void *, const int*, void *); + +/* Double Complex Precision */ + + void zgemv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*); + void zgbmv_(char*, const int*, const int*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*); + void zhemv_(char*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*); + void zhbmv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*); + void zhpmv_(char*, const int*, const void *, const void *, const void *, const int*, const void *, void *, const int*); + void ztrmv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*); + void ztbmv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*); + void ztpmv_( char*, char*, char*, const int*, const void *, void *, const int*); + void ztrsv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*); + void ztbsv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*); + void ztpsv_( char*, char*, char*, const int*, const void *, void *,const int*); + void zgerc_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*); + void zgeru_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*); + void zher_(char*, const int*, const double *, const void *, const int*, void *, const int*); + void zher2_(char*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*); + void zhpr_(char*, const int*, const double *, const void *, const int*, void *); + void zhpr2_(char*, const int*, const double *, const void *, const int*, const void *, const int*, void *); + +/***********/ +/* Level 3 */ +/***********/ + +/* Single Precision */ + + void sgemm_(char*, char*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); + void ssymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); + void ssyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*); + void ssyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); + void strmm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*); + void strsm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*); + +/* Double Precision */ + + void dgemm_(char*, char*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); + void dsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); + void dsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*); + void dsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); + void dtrmm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*); + void dtrsm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*); + +/* Single Complex Precision */ + + void cgemm_(char*, char*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); + void csymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); + void chemm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); + void csyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*); + void cherk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*); + void csyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); + void cher2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); + void ctrmm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*); + void ctrsm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*); + +/* Double Complex Precision */ + + void zgemm_(char*, char*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); + void zsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); + void zhemm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); + void zsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*); + void zherk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*); + void zsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); + void zher2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); + void ztrmm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*); + void ztrsm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*); + +} + +#endif /* PYTENSOR_FORTRAN_BLAS_H */ + diff --git a/pytensor/tensor/blas/c_code/mkl_threads.h b/pytensor/tensor/blas/c_code/mkl_threads.h new file mode 100644 index 0000000000..cda2a40d1e --- /dev/null +++ b/pytensor/tensor/blas/c_code/mkl_threads.h @@ -0,0 +1,33 @@ +/* + * MKL threads interface declarations for PyTensor. + */ + +#ifndef PYTENSOR_MKL_THREADS_H +#define PYTENSOR_MKL_THREADS_H + +extern "C" +{ + int MKL_Set_Num_Threads_Local(int); + #define mkl_set_num_threads_local MKL_Set_Num_Threads_Local + + void MKL_Set_Num_Threads(int); + #define mkl_set_num_threads MKL_Set_Num_Threads + + int MKL_Get_Max_Threads(void); + #define mkl_get_max_threads MKL_Get_Max_Threads + + int MKL_Domain_Set_Num_Threads(int, int); + #define mkl_domain_set_num_threads MKL_Domain_Set_Num_Threads + + int MKL_Domain_Get_Max_Threads(int); + #define mkl_domain_get_max_threads MKL_Domain_Get_Max_Threads + + void MKL_Set_Dynamic(int); + #define mkl_set_dynamic MKL_Set_Dynamic + + int MKL_Get_Dynamic(void); + #define mkl_get_dynamic MKL_Get_Dynamic +} + +#endif /* PYTENSOR_MKL_THREADS_H */ + diff --git a/pytensor/tensor/blas/c_code/openblas_threads.h b/pytensor/tensor/blas/c_code/openblas_threads.h new file mode 100644 index 0000000000..3a01c44857 --- /dev/null +++ b/pytensor/tensor/blas/c_code/openblas_threads.h @@ -0,0 +1,16 @@ +/* + * OpenBLAS threads interface declarations for PyTensor. + */ + +#ifndef PYTENSOR_OPENBLAS_THREADS_H +#define PYTENSOR_OPENBLAS_THREADS_H + +extern "C" +{ + void openblas_set_num_threads(int); + void goto_set_num_threads(int); + int openblas_get_num_threads(void); +} + +#endif /* PYTENSOR_OPENBLAS_THREADS_H */ + diff --git a/pytensor/tensor/blas/gemm.py b/pytensor/tensor/blas/gemm.py index b15a8fd37e..02cdb9785a 100644 --- a/pytensor/tensor/blas/gemm.py +++ b/pytensor/tensor/blas/gemm.py @@ -28,7 +28,6 @@ class GemmRelated(COp): __props__: tuple[str, ...] = () def c_support_code(self, **kwargs): - # return cblas_header_text() mod_str = """ #ifndef MOD #define MOD % From 565bb56bedab87c47721c19913f09c1b6b41821c Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 12 Apr 2026 18:31:13 -0500 Subject: [PATCH 04/13] Remove macOS 10.6 sdot bug workaround --- pytensor/tensor/blas/blas_headers.py | 182 +-------------------------- 1 file changed, 2 insertions(+), 180 deletions(-) diff --git a/pytensor/tensor/blas/blas_headers.py b/pytensor/tensor/blas/blas_headers.py index 5d49b70ec4..9677593159 100644 --- a/pytensor/tensor/blas/blas_headers.py +++ b/pytensor/tensor/blas/blas_headers.py @@ -6,155 +6,14 @@ """ import logging -import os -import sys -import textwrap from pathlib import Path from pytensor.configdefaults import config -from pytensor.link.c.cmodule import GCC_compiler _logger = logging.getLogger("pytensor.tensor.blas") -def detect_macos_sdot_bug(): - """ - Try to detect a bug in the default BLAS in MacOS. - - The problem in PyTensor has been reported in gh-1240, - the underlying bug has been confirmed in - http://www.macresearch.org/lapackblas-fortran-106#comment-17227. - - This function tries to compile code triggering that bug, - and, if necessary, an attempted fix. - - Three attributes of this function will be set: - - detect_macos_sdot_bug.tested will be set to True - when this function is called. - - detect_macos_sdot_bug.present will be set to True if the bug is - detected. Its value is returned by the function - - detect_macos_sdot_bug.fix_works will be set to True if the fix was - attempted, and succeeded. - - """ - _logger.debug("Starting detection of bug in Mac OS BLAS sdot_ routine") - if detect_macos_sdot_bug.tested: - return detect_macos_sdot_bug.present - - if sys.platform != "darwin" or not config.blas__ldflags: - _logger.info("Not Mac OS, no sdot_ bug") - detect_macos_sdot_bug.tested = True - return False - - # This code will return -1 if the dot product did not return - # the right value (30.). - flags = config.blas__ldflags.split() - for f in flags: - # Library directories should also be added as rpath, - # so that they can be loaded even if the environment - # variable LD_LIBRARY_PATH does not contain them - lib_path = os.environ.get("DYLD_FALLBACK_LIBRARY_PATH", "").split(":") - if f.startswith("-L"): - flags.append("-Wl,-rpath," + f[2:]) - # also append those paths to DYLD_FALLBACK_LIBRARY_PATH to - # support libraries that have the wrong install_name - # (such as MKL on canopy installs) - if f[2:] not in lib_path: - lib_path.append(f[2:]) - # this goes into the python process environment that is - # inherited by subprocesses/used by dyld when loading new objects - os.environ["DYLD_FALLBACK_LIBRARY_PATH"] = ":".join(lib_path) - - test_code = textwrap.dedent( - """\ - extern "C" float sdot_(int*, float*, int*, float*, int*); - int main(int argc, char** argv) - { - int Nx = 5; - int Sx = 1; - float x[5] = {0, 1, 2, 3, 4}; - float r = sdot_(&Nx, x, &Sx, x, &Sx); - - if ((r - 30.f) > 1e-6 || (r - 30.f) < -1e-6) - { - return -1; - } - return 0; - } - """ - ) - - _logger.debug("Trying to compile and run test case.") - compilation_ok, run_ok = GCC_compiler.try_compile_tmp( - test_code, tmp_prefix="detect_macos_sdot_bug_", flags=flags, try_run=True - ) - detect_macos_sdot_bug.tested = True - - # If compilation failed, we consider there is a bug, - # and the fix does not work - if not compilation_ok: - _logger.info("Could not compile test case for sdot_.") - detect_macos_sdot_bug.present = True - return True - - if run_ok: - _logger.info("The sdot_ bug is not present on this system.") - detect_macos_sdot_bug.present = False - return False - - # Else, the bug is detected. - _logger.info("The sdot_ bug is present on this system.") - detect_macos_sdot_bug.present = True - - # Then, try a simple fix - test_fix_code = textwrap.dedent( - """\ - extern "C" float cblas_sdot(int, float*, int, float*, int); - static float sdot_(int* Nx, float* x, int* Sx, float* y, int* Sy) - { - return cblas_sdot(*Nx, x, *Sx, y, *Sy); - } - - int main(int argc, char** argv) - { - int Nx = 5; - int Sx = 1; - float x[5] = {0, 1, 2, 3, 4}; - float r = sdot_(&Nx, x, &Sx, x, &Sx); - - if ((r - 30.f) > 1e-6 || (r - 30.f) < -1e-6) - { - return -1; - } - return 0; - } - """ - ) - - _logger.debug("Trying to compile and run tentative workaround.") - compilation_fix_ok, run_fix_ok = GCC_compiler.try_compile_tmp( - test_fix_code, - tmp_prefix="detect_macos_sdot_bug_testfix_", - flags=flags, - try_run=True, - ) - - _logger.info( - "Status of tentative fix -- compilation OK: %s, works: %s", - compilation_fix_ok, - run_fix_ok, - ) - detect_macos_sdot_bug.fix_works = run_fix_ok - - return detect_macos_sdot_bug.present - - -detect_macos_sdot_bug.tested = False -detect_macos_sdot_bug.present = False -detect_macos_sdot_bug.fix_works = False - - def cblas_header_text(): """C header for the cblas interface.""" @@ -977,34 +836,6 @@ def blas_header_text(): } """ - if detect_macos_sdot_bug(): - if detect_macos_sdot_bug.fix_works: - header += textwrap.dedent( - """\ - extern "C" float cblas_sdot(int, float*, int, float*, int); - static float sdot_(int* Nx, float* x, int* Sx, float* y, int* Sy) - { - return cblas_sdot(*Nx, x, *Sx, y, *Sy); - } - """ - ) - else: - # Make sure the buggy version of sdot_ is never used - header += textwrap.dedent( - """\ - static float sdot_(int* Nx, float* x, int* Sx, float* y, int* Sy) - { - fprintf(stderr, - "FATAL: The implementation of BLAS SDOT " - "routine in your system has a bug that " - "makes it return wrong results.\\n" - "You can work around this bug by using a " - "different BLAS library, or disabling BLAS\\n"); - assert(0); - } - """ - ) - return header + blas_code @@ -1052,17 +883,8 @@ def openblas_threads_text(): def blas_header_version(): - # Version for the base header - version = (10,) - if detect_macos_sdot_bug(): - if detect_macos_sdot_bug.fix_works: - # Version with fix - version += (1,) - else: - # Version with error - version += (2,) - - return version + # Version 11: Removed obsolete macOS 10.6 sdot bug workaround + return (11,) def ____gemm_code(check_ab, a_init, b_init): From e17164ed141b41f738654bc0dd9a36b49dacc7e8 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 12 Apr 2026 18:34:21 -0500 Subject: [PATCH 05/13] Remove unused blas_c_headers --- pytensor/tensor/blas/blas_headers.py | 582 --------------------------- 1 file changed, 582 deletions(-) diff --git a/pytensor/tensor/blas/blas_headers.py b/pytensor/tensor/blas/blas_headers.py index 9677593159..affd937ee7 100644 --- a/pytensor/tensor/blas/blas_headers.py +++ b/pytensor/tensor/blas/blas_headers.py @@ -14,588 +14,6 @@ _logger = logging.getLogger("pytensor.tensor.blas") -def cblas_header_text(): - """C header for the cblas interface.""" - - return """ - //#include - - #undef __BEGIN_DECLS - #undef __END_DECLS - #ifdef __cplusplus - #define __BEGIN_DECLS extern "C" { - #define __END_DECLS } - #else - #define __BEGIN_DECLS /* empty */ - #define __END_DECLS /* empty */ - #endif - - __BEGIN_DECLS - - #define MOD % - - /* - * Enumerated and derived types - */ - #define CBLAS_INDEX size_t /* this may vary between platforms */ - - enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102}; - enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113}; - enum CBLAS_UPLO {CblasUpper=121, CblasLower=122}; - enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132}; - enum CBLAS_SIDE {CblasLeft=141, CblasRight=142}; - - float cblas_sdsdot(const int N, const float alpha, const float *X, - const int incX, const float *Y, const int incY); - double cblas_dsdot(const int N, const float *X, const int incX, const float *Y, - const int incY); - float cblas_sdot(const int N, const float *X, const int incX, - const float *Y, const int incY); - double cblas_ddot(const int N, const double *X, const int incX, - const double *Y, const int incY); - - /* - * Functions having prefixes Z and C only - */ - void cblas_cdotu_sub(const int N, const void *X, const int incX, - const void *Y, const int incY, void *dotu); - void cblas_cdotc_sub(const int N, const void *X, const int incX, - const void *Y, const int incY, void *dotc); - - void cblas_zdotu_sub(const int N, const void *X, const int incX, - const void *Y, const int incY, void *dotu); - void cblas_zdotc_sub(const int N, const void *X, const int incX, - const void *Y, const int incY, void *dotc); - - - /* - * Functions having prefixes S D SC DZ - */ - float cblas_snrm2(const int N, const float *X, const int incX); - float cblas_sasum(const int N, const float *X, const int incX); - - double cblas_dnrm2(const int N, const double *X, const int incX); - double cblas_dasum(const int N, const double *X, const int incX); - - float cblas_scnrm2(const int N, const void *X, const int incX); - float cblas_scasum(const int N, const void *X, const int incX); - - double cblas_dznrm2(const int N, const void *X, const int incX); - double cblas_dzasum(const int N, const void *X, const int incX); - - - /* - * Functions having standard 4 prefixes (S D C Z) - */ - CBLAS_INDEX cblas_isamax(const int N, const float *X, const int incX); - CBLAS_INDEX cblas_idamax(const int N, const double *X, const int incX); - CBLAS_INDEX cblas_icamax(const int N, const void *X, const int incX); - CBLAS_INDEX cblas_izamax(const int N, const void *X, const int incX); - - /* - * =========================================================================== - * Prototypes for level 1 BLAS routines - * =========================================================================== - */ - - /* - * Routines with standard 4 prefixes (s, d, c, z) - */ - void cblas_sswap(const int N, float *X, const int incX, - float *Y, const int incY); - void cblas_scopy(const int N, const float *X, const int incX, - float *Y, const int incY); - void cblas_saxpy(const int N, const float alpha, const float *X, - const int incX, float *Y, const int incY); - - void cblas_dswap(const int N, double *X, const int incX, - double *Y, const int incY); - void cblas_dcopy(const int N, const double *X, const int incX, - double *Y, const int incY); - void cblas_daxpy(const int N, const double alpha, const double *X, - const int incX, double *Y, const int incY); - - void cblas_cswap(const int N, void *X, const int incX, - void *Y, const int incY); - void cblas_ccopy(const int N, const void *X, const int incX, - void *Y, const int incY); - void cblas_caxpy(const int N, const void *alpha, const void *X, - const int incX, void *Y, const int incY); - - void cblas_zswap(const int N, void *X, const int incX, - void *Y, const int incY); - void cblas_zcopy(const int N, const void *X, const int incX, - void *Y, const int incY); - void cblas_zaxpy(const int N, const void *alpha, const void *X, - const int incX, void *Y, const int incY); - - - /* - * Routines with S and D prefix only - */ - void cblas_srotg(float *a, float *b, float *c, float *s); - void cblas_srotmg(float *d1, float *d2, float *b1, const float b2, float *P); - void cblas_srot(const int N, float *X, const int incX, - float *Y, const int incY, const float c, const float s); - void cblas_srotm(const int N, float *X, const int incX, - float *Y, const int incY, const float *P); - - void cblas_drotg(double *a, double *b, double *c, double *s); - void cblas_drotmg(double *d1, double *d2, double *b1, const double b2, double *P); - void cblas_drot(const int N, double *X, const int incX, - double *Y, const int incY, const double c, const double s); - void cblas_drotm(const int N, double *X, const int incX, - double *Y, const int incY, const double *P); - - - /* - * Routines with S D C Z CS and ZD prefixes - */ - void cblas_sscal(const int N, const float alpha, float *X, const int incX); - void cblas_dscal(const int N, const double alpha, double *X, const int incX); - void cblas_cscal(const int N, const void *alpha, void *X, const int incX); - void cblas_zscal(const int N, const void *alpha, void *X, const int incX); - void cblas_csscal(const int N, const float alpha, void *X, const int incX); - void cblas_zdscal(const int N, const double alpha, void *X, const int incX); - - /* - * =========================================================================== - * Prototypes for level 2 BLAS - * =========================================================================== - */ - - /* - * Routines with standard 4 prefixes (S, D, C, Z) - */ - void cblas_sgemv(const enum CBLAS_ORDER order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const float alpha, const float *A, const int lda, - const float *X, const int incX, const float beta, - float *Y, const int incY); - void cblas_sgbmv(const enum CBLAS_ORDER order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const int KL, const int KU, const float alpha, - const float *A, const int lda, const float *X, - const int incX, const float beta, float *Y, const int incY); - void cblas_strmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const float *A, const int lda, - float *X, const int incX); - void cblas_stbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const float *A, const int lda, - float *X, const int incX); - void cblas_stpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const float *Ap, float *X, const int incX); - void cblas_strsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const float *A, const int lda, float *X, - const int incX); - void cblas_stbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const float *A, const int lda, - float *X, const int incX); - void cblas_stpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const float *Ap, float *X, const int incX); - - void cblas_dgemv(const enum CBLAS_ORDER order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const double alpha, const double *A, const int lda, - const double *X, const int incX, const double beta, - double *Y, const int incY); - void cblas_dgbmv(const enum CBLAS_ORDER order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const int KL, const int KU, const double alpha, - const double *A, const int lda, const double *X, - const int incX, const double beta, double *Y, const int incY); - void cblas_dtrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const double *A, const int lda, - double *X, const int incX); - void cblas_dtbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const double *A, const int lda, - double *X, const int incX); - void cblas_dtpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const double *Ap, double *X, const int incX); - void cblas_dtrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const double *A, const int lda, double *X, - const int incX); - void cblas_dtbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const double *A, const int lda, - double *X, const int incX); - void cblas_dtpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const double *Ap, double *X, const int incX); - - void cblas_cgemv(const enum CBLAS_ORDER order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const void *alpha, const void *A, const int lda, - const void *X, const int incX, const void *beta, - void *Y, const int incY); - void cblas_cgbmv(const enum CBLAS_ORDER order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const int KL, const int KU, const void *alpha, - const void *A, const int lda, const void *X, - const int incX, const void *beta, void *Y, const int incY); - void cblas_ctrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const void *A, const int lda, - void *X, const int incX); - void cblas_ctbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const void *A, const int lda, - void *X, const int incX); - void cblas_ctpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const void *Ap, void *X, const int incX); - void cblas_ctrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const void *A, const int lda, void *X, - const int incX); - void cblas_ctbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const void *A, const int lda, - void *X, const int incX); - void cblas_ctpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const void *Ap, void *X, const int incX); - - void cblas_zgemv(const enum CBLAS_ORDER order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const void *alpha, const void *A, const int lda, - const void *X, const int incX, const void *beta, - void *Y, const int incY); - void cblas_zgbmv(const enum CBLAS_ORDER order, - const enum CBLAS_TRANSPOSE TransA, const int M, const int N, - const int KL, const int KU, const void *alpha, - const void *A, const int lda, const void *X, - const int incX, const void *beta, void *Y, const int incY); - void cblas_ztrmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const void *A, const int lda, - void *X, const int incX); - void cblas_ztbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const void *A, const int lda, - void *X, const int incX); - void cblas_ztpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const void *Ap, void *X, const int incX); - void cblas_ztrsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const void *A, const int lda, void *X, - const int incX); - void cblas_ztbsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const int K, const void *A, const int lda, - void *X, const int incX); - void cblas_ztpsv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, - const int N, const void *Ap, void *X, const int incX); - - - /* - * Routines with S and D prefixes only - */ - void cblas_ssymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const float *A, - const int lda, const float *X, const int incX, - const float beta, float *Y, const int incY); - void cblas_ssbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const int K, const float alpha, const float *A, - const int lda, const float *X, const int incX, - const float beta, float *Y, const int incY); - void cblas_sspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const float *Ap, - const float *X, const int incX, - const float beta, float *Y, const int incY); - void cblas_sger(const enum CBLAS_ORDER order, const int M, const int N, - const float alpha, const float *X, const int incX, - const float *Y, const int incY, float *A, const int lda); - void cblas_ssyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const float *X, - const int incX, float *A, const int lda); - void cblas_sspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const float *X, - const int incX, float *Ap); - void cblas_ssyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const float *X, - const int incX, const float *Y, const int incY, float *A, - const int lda); - void cblas_sspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const float *X, - const int incX, const float *Y, const int incY, float *A); - - void cblas_dsymv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const double *A, - const int lda, const double *X, const int incX, - const double beta, double *Y, const int incY); - void cblas_dsbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const int K, const double alpha, const double *A, - const int lda, const double *X, const int incX, - const double beta, double *Y, const int incY); - void cblas_dspmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const double *Ap, - const double *X, const int incX, - const double beta, double *Y, const int incY); - void cblas_dger(const enum CBLAS_ORDER order, const int M, const int N, - const double alpha, const double *X, const int incX, - const double *Y, const int incY, double *A, const int lda); - void cblas_dsyr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const double *X, - const int incX, double *A, const int lda); - void cblas_dspr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const double *X, - const int incX, double *Ap); - void cblas_dsyr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const double *X, - const int incX, const double *Y, const int incY, double *A, - const int lda); - void cblas_dspr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const double *X, - const int incX, const double *Y, const int incY, double *A); - - - /* - * Routines with C and Z prefixes only - */ - void cblas_chemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const void *alpha, const void *A, - const int lda, const void *X, const int incX, - const void *beta, void *Y, const int incY); - void cblas_chbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const int K, const void *alpha, const void *A, - const int lda, const void *X, const int incX, - const void *beta, void *Y, const int incY); - void cblas_chpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const void *alpha, const void *Ap, - const void *X, const int incX, - const void *beta, void *Y, const int incY); - void cblas_cgeru(const enum CBLAS_ORDER order, const int M, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *A, const int lda); - void cblas_cgerc(const enum CBLAS_ORDER order, const int M, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *A, const int lda); - void cblas_cher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const void *X, const int incX, - void *A, const int lda); - void cblas_chpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const float alpha, const void *X, - const int incX, void *A); - void cblas_cher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *A, const int lda); - void cblas_chpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *Ap); - - void cblas_zhemv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const void *alpha, const void *A, - const int lda, const void *X, const int incX, - const void *beta, void *Y, const int incY); - void cblas_zhbmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const int K, const void *alpha, const void *A, - const int lda, const void *X, const int incX, - const void *beta, void *Y, const int incY); - void cblas_zhpmv(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const void *alpha, const void *Ap, - const void *X, const int incX, - const void *beta, void *Y, const int incY); - void cblas_zgeru(const enum CBLAS_ORDER order, const int M, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *A, const int lda); - void cblas_zgerc(const enum CBLAS_ORDER order, const int M, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *A, const int lda); - void cblas_zher(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const void *X, const int incX, - void *A, const int lda); - void cblas_zhpr(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, - const int N, const double alpha, const void *X, - const int incX, void *A); - void cblas_zher2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *A, const int lda); - void cblas_zhpr2(const enum CBLAS_ORDER order, const enum CBLAS_UPLO Uplo, const int N, - const void *alpha, const void *X, const int incX, - const void *Y, const int incY, void *Ap); - - /* - * =========================================================================== - * Prototypes for level 3 BLAS - * =========================================================================== - */ - - /* - * Routines with standard 4 prefixes (S, D, C, Z) - */ - void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_TRANSPOSE TransB, const int M, const int N, - const int K, const float alpha, const float *A, - const int lda, const float *B, const int ldb, - const float beta, float *C, const int ldc); - void cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const int M, const int N, - const float alpha, const float *A, const int lda, - const float *B, const int ldb, const float beta, - float *C, const int ldc); - void cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const float alpha, const float *A, const int lda, - const float beta, float *C, const int ldc); - void cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const float alpha, const float *A, const int lda, - const float *B, const int ldb, const float beta, - float *C, const int ldc); - void cblas_strmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const float alpha, const float *A, const int lda, - float *B, const int ldb); - void cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const float alpha, const float *A, const int lda, - float *B, const int ldb); - - void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_TRANSPOSE TransB, const int M, const int N, - const int K, const double alpha, const double *A, - const int lda, const double *B, const int ldb, - const double beta, double *C, const int ldc); - void cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const int M, const int N, - const double alpha, const double *A, const int lda, - const double *B, const int ldb, const double beta, - double *C, const int ldc); - void cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const double alpha, const double *A, const int lda, - const double beta, double *C, const int ldc); - void cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const double alpha, const double *A, const int lda, - const double *B, const int ldb, const double beta, - double *C, const int ldc); - void cblas_dtrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const double alpha, const double *A, const int lda, - double *B, const int ldb); - void cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const double alpha, const double *A, const int lda, - double *B, const int ldb); - - void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_TRANSPOSE TransB, const int M, const int N, - const int K, const void *alpha, const void *A, - const int lda, const void *B, const int ldb, - const void *beta, void *C, const int ldc); - void cblas_csymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const int M, const int N, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const void *beta, - void *C, const int ldc); - void cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const void *alpha, const void *A, const int lda, - const void *beta, void *C, const int ldc); - void cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const void *beta, - void *C, const int ldc); - void cblas_ctrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const void *alpha, const void *A, const int lda, - void *B, const int ldb); - void cblas_ctrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const void *alpha, const void *A, const int lda, - void *B, const int ldb); - - void cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_TRANSPOSE TransB, const int M, const int N, - const int K, const void *alpha, const void *A, - const int lda, const void *B, const int ldb, - const void *beta, void *C, const int ldc); - void cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const int M, const int N, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const void *beta, - void *C, const int ldc); - void cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const void *alpha, const void *A, const int lda, - const void *beta, void *C, const int ldc); - void cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const void *beta, - void *C, const int ldc); - void cblas_ztrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const void *alpha, const void *A, const int lda, - void *B, const int ldb); - void cblas_ztrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, - const enum CBLAS_DIAG Diag, const int M, const int N, - const void *alpha, const void *A, const int lda, - void *B, const int ldb); - - - /* - * Routines with prefixes C and Z only - */ - void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const int M, const int N, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const void *beta, - void *C, const int ldc); - void cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const float alpha, const void *A, const int lda, - const float beta, void *C, const int ldc); - void cblas_cher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const float beta, - void *C, const int ldc); - - void cblas_zhemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, - const enum CBLAS_UPLO Uplo, const int M, const int N, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const void *beta, - void *C, const int ldc); - void cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const double alpha, const void *A, const int lda, - const double beta, void *C, const int ldc); - void cblas_zher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, - const enum CBLAS_TRANSPOSE Trans, const int N, const int K, - const void *alpha, const void *A, const int lda, - const void *B, const int ldb, const double beta, - void *C, const int ldc); - - void cblas_xerbla(int p, const char *rout, const char *form, ...); - - __END_DECLS - """ - - def blas_header_text(): """C header for the fortran blas interface""" From 5aea945b7ed623e9418876e6c874836599132a85 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 12 Apr 2026 18:36:50 -0500 Subject: [PATCH 06/13] Refactor blas_headers.py to use external .h files --- pytensor/tensor/blas/blas_headers.py | 426 +++------------------------ 1 file changed, 38 insertions(+), 388 deletions(-) diff --git a/pytensor/tensor/blas/blas_headers.py b/pytensor/tensor/blas/blas_headers.py index affd937ee7..d014cbae39 100644 --- a/pytensor/tensor/blas/blas_headers.py +++ b/pytensor/tensor/blas/blas_headers.py @@ -3,8 +3,12 @@ There is no standard name or location for this header, so we just insert it ourselves into the C code. +The static C declarations are stored in .h files under c_code/ for better +IDE support and maintainability. This module reads those files and assembles +the complete header text. """ +import functools import logging from pathlib import Path @@ -13,10 +17,29 @@ _logger = logging.getLogger("pytensor.tensor.blas") +# Directory containing the C header files +_C_CODE_DIR = Path(__file__).parent / "c_code" + +@functools.cache +def _read_c_code_file(filename: str) -> str: + """Read a C code file from the c_code directory.""" + filepath = _C_CODE_DIR / filename + try: + return filepath.read_text(encoding="utf-8") + except OSError as err: + msg = f"Unable to load C header file: {filepath}" + raise OSError(msg) from err + + +@functools.cache def blas_header_text(): - """C header for the fortran blas interface""" + """C header for the fortran blas interface. + Returns the complete BLAS header text including: + - Fortran BLAS declarations (from fortran_blas.h) + - NumPy-based fallback BLAS (if no system BLAS available) + """ blas_code = "" if not config.blas__ldflags: # This code can only be reached by compiling a function with a manually specified GEMM Op. @@ -25,12 +48,9 @@ def blas_header_text(): _logger.warning("Using NumPy C-API based implementation for BLAS functions.") # Include the Numpy version implementation of [sd]gemm_. - current_filedir = Path(__file__).parent - blas_common_filepath = current_filedir / "c_code/alt_blas_common.h" - blas_template_filepath = current_filedir / "c_code/alt_blas_template.c" try: - common_code = blas_common_filepath.read_text(encoding="utf-8") - template_code = blas_template_filepath.read_text(encoding="utf-8") + common_code = _read_c_code_file("alt_blas_common.h") + template_code = _read_c_code_file("alt_blas_template.c") except OSError as err: msg = "Unable to load NumPy implementation of BLAS functions from C source files." raise OSError(msg) from err @@ -50,398 +70,28 @@ def blas_header_text(): blas_code += sblas_code blas_code += dblas_code - header = """ - extern "C" - { - - void xerbla_(char*, void *); - - /***********/ - /* Level 1 */ - /***********/ - - /* Single Precision */ - - void srot_(const int*, float *, const int*, float *, const int*, const float *, const float *); - void srotg_(float *,float *,float *,float *); - void srotm_( const int*, float *, const int*, float *, const int*, const float *); - void srotmg_(float *,float *,float *,const float *, float *); - void sswap_( const int*, float *, const int*, float *, const int*); - void scopy_( const int*, const float *, const int*, float *, const int*); - void saxpy_( const int*, const float *, const float *, const int*, float *, const int*); - float sdot_(const int*, const float *, const int*, const float *, const int*); - void sdot_sub_(const int*, const float *, const int*, const float *, const int*, float *); - void sdsdot_sub_( const int*, const float *, const float *, const int*, const float *, const int*, float *); - void sscal_( const int*, const float *, float *, const int*); - void snrm2_sub_( const int*, const float *, const int*, float *); - void sasum_sub_( const int*, const float *, const int*, float *); - void isamax_sub_( const int*, const float * , const int*, const int*); - - /* Double Precision */ - - void drot_(const int*, double *, const int*, double *, const int*, const double *, const double *); - void drotg_(double *,double *,double *,double *); - void drotm_( const int*, double *, const int*, double *, const int*, const double *); - void drotmg_(double *,double *,double *,const double *, double *); - void dswap_( const int*, double *, const int*, double *, const int*); - void dcopy_( const int*, const double *, const int*, double *, const int*); - void daxpy_( const int*, const double *, const double *, const int*, double *, const int*); - void dswap_( const int*, double *, const int*, double *, const int*); - double ddot_(const int*, const double *, const int*, const double *, const int*); - void dsdot_sub_(const int*, const float *, const int*, const float *, const int*, double *); - void ddot_sub_( const int*, const double *, const int*, const double *, const int*, double *); - void dscal_( const int*, const double *, double *, const int*); - void dnrm2_sub_( const int*, const double *, const int*, double *); - void dasum_sub_( const int*, const double *, const int*, double *); - void idamax_sub_( const int*, const double * , const int*, const int*); - - /* Single Complex Precision */ - - void cswap_( const int*, void *, const int*, void *, const int*); - void ccopy_( const int*, const void *, const int*, void *, const int*); - void caxpy_( const int*, const void *, const void *, const int*, void *, const int*); - void cswap_( const int*, void *, const int*, void *, const int*); - void cdotc_sub_( const int*, const void *, const int*, const void *, const int*, void *); - void cdotu_sub_( const int*, const void *, const int*, const void *, const int*, void *); - void cscal_( const int*, const void *, void *, const int*); - void icamax_sub_( const int*, const void *, const int*, const int*); - void csscal_( const int*, const float *, void *, const int*); - void scnrm2_sub_( const int*, const void *, const int*, float *); - void scasum_sub_( const int*, const void *, const int*, float *); - - /* Double Complex Precision */ - - void zswap_( const int*, void *, const int*, void *, const int*); - void zcopy_( const int*, const void *, const int*, void *, const int*); - void zaxpy_( const int*, const void *, const void *, const int*, void *, const int*); - void zswap_( const int*, void *, const int*, void *, const int*); - void zdotc_sub_( const int*, const void *, const int*, const void *, const int*, void *); - void zdotu_sub_( const int*, const void *, const int*, const void *, const int*, void *); - void zdscal_( const int*, const double *, void *, const int*); - void zscal_( const int*, const void *, void *, const int*); - void dznrm2_sub_( const int*, const void *, const int*, double *); - void dzasum_sub_( const int*, const void *, const int*, double *); - void izamax_sub_( const int*, const void *, const int*, const int*); - - /***********/ - /* Level 2 */ - /***********/ - - /* Single Precision */ - - void sgemv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); - void sgbmv_(char*, const int*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); - void ssymv_(char*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); - void ssbmv_(char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); - void sspmv_(char*, const int*, const float *, const float *, const float *, const int*, const float *, float *, const int*); - void strmv_( char*, char*, char*, const int*, const float *, const int*, float *, const int*); - void stbmv_( char*, char*, char*, const int*, const int*, const float *, const int*, float *, const int*); - void strsv_( char*, char*, char*, const int*, const float *, const int*, float *, const int*); - void stbsv_( char*, char*, char*, const int*, const int*, const float *, const int*, float *, const int*); - void stpmv_( char*, char*, char*, const int*, const float *, float *, const int*); - void stpsv_( char*, char*, char*, const int*, const float *, float *, const int*); - void sger_( const int*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*); - void ssyr_(char*, const int*, const float *, const float *, const int*, float *, const int*); - void sspr_(char*, const int*, const float *, const float *, const int*, float *); - void sspr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *); - void ssyr2_(char*, const int*, const float *, const float *, const int*, const float *, const int*, float *, const int*); - - /* Double Precision */ - - void dgemv_(char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); - void dgbmv_(char*, const int*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); - void dsymv_(char*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); - void dsbmv_(char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); - void dspmv_(char*, const int*, const double *, const double *, const double *, const int*, const double *, double *, const int*); - void dtrmv_( char*, char*, char*, const int*, const double *, const int*, double *, const int*); - void dtbmv_( char*, char*, char*, const int*, const int*, const double *, const int*, double *, const int*); - void dtrsv_( char*, char*, char*, const int*, const double *, const int*, double *, const int*); - void dtbsv_( char*, char*, char*, const int*, const int*, const double *, const int*, double *, const int*); - void dtpmv_( char*, char*, char*, const int*, const double *, double *, const int*); - void dtpsv_( char*, char*, char*, const int*, const double *, double *, const int*); - void dger_( const int*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*); - void dsyr_(char*, const int*, const double *, const double *, const int*, double *, const int*); - void dspr_(char*, const int*, const double *, const double *, const int*, double *); - void dspr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *); - void dsyr2_(char*, const int*, const double *, const double *, const int*, const double *, const int*, double *, const int*); - - /* Single Complex Precision */ - - void cgemv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*); - void cgbmv_(char*, const int*, const int*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*); - void chemv_(char*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*); - void chbmv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*); - void chpmv_(char*, const int*, const void *, const void *, const void *, const int*, const void *, void *, const int*); - void ctrmv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*); - void ctbmv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*); - void ctpmv_( char*, char*, char*, const int*, const void *, void *, const int*); - void ctrsv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*); - void ctbsv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*); - void ctpsv_( char*, char*, char*, const int*, const void *, void *,const int*); - void cgerc_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*); - void cgeru_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*); - void cher_(char*, const int*, const float *, const void *, const int*, void *, const int*); - void cher2_(char*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*); - void chpr_(char*, const int*, const float *, const void *, const int*, void *); - void chpr2_(char*, const int*, const float *, const void *, const int*, const void *, const int*, void *); - - /* Double Complex Precision */ - - void zgemv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*); - void zgbmv_(char*, const int*, const int*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*); - void zhemv_(char*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*); - void zhbmv_(char*, const int*, const int*, const void *, const void *, const int*, const void *, const int*, const void *, void *, const int*); - void zhpmv_(char*, const int*, const void *, const void *, const void *, const int*, const void *, void *, const int*); - void ztrmv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*); - void ztbmv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*); - void ztpmv_( char*, char*, char*, const int*, const void *, void *, const int*); - void ztrsv_( char*, char*, char*, const int*, const void *, const int*, void *, const int*); - void ztbsv_( char*, char*, char*, const int*, const int*, const void *, const int*, void *, const int*); - void ztpsv_( char*, char*, char*, const int*, const void *, void *,const int*); - void zgerc_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*); - void zgeru_( const int*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*); - void zher_(char*, const int*, const double *, const void *, const int*, void *, const int*); - void zher2_(char*, const int*, const void *, const void *, const int*, const void *, const int*, void *, const int*); - void zhpr_(char*, const int*, const double *, const void *, const int*, void *); - void zhpr2_(char*, const int*, const double *, const void *, const int*, const void *, const int*, void *); - - /***********/ - /* Level 3 */ - /***********/ - - /* Single Precision */ - - void sgemm_(char*, char*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); - void ssymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); - void ssyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*); - void ssyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); - void strmm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*); - void strsm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*); - - /* Double Precision */ - - void dgemm_(char*, char*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); - void dsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); - void dsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*); - void dsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); - void dtrmm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*); - void dtrsm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*); - - /* Single Complex Precision */ - - void cgemm_(char*, char*, const int*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); - void csymm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); - void chemm_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); - void csyrk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*); - void cherk_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, float *, const int*); - void csyr2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); - void cher2k_(char*, char*, const int*, const int*, const float *, const float *, const int*, const float *, const int*, const float *, float *, const int*); - void ctrmm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*); - void ctrsm_(char*, char*, char*, char*, const int*, const int*, const float *, const float *, const int*, float *, const int*); - - /* Double Complex Precision */ - - void zgemm_(char*, char*, const int*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); - void zsymm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); - void zhemm_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); - void zsyrk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*); - void zherk_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, double *, const int*); - void zsyr2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); - void zher2k_(char*, char*, const int*, const int*, const double *, const double *, const int*, const double *, const int*, const double *, double *, const int*); - void ztrmm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*); - void ztrsm_(char*, char*, char*, char*, const int*, const int*, const double *, const double *, const int*, double *, const int*); - - } - """ + # Read the Fortran BLAS declarations from the static header file + header = _read_c_code_file("fortran_blas.h") return header + blas_code +@functools.cache def mkl_threads_text(): - """C header for MKL threads interface""" - header = """ - extern "C" - { - int MKL_Set_Num_Threads_Local(int); - #define mkl_set_num_threads_local MKL_Set_Num_Threads_Local - - void MKL_Set_Num_Threads(int); - #define mkl_set_num_threads MKL_Set_Num_Threads - - int MKL_Get_Max_Threads(void); - #define mkl_get_max_threads MKL_Get_Max_Threads - - int MKL_Domain_Set_Num_Threads(int, int); - #define mkl_domain_set_num_threads MKL_Domain_Set_Num_Threads - - int MKL_Domain_Get_Max_Threads(int); - #define mkl_domain_get_max_threads MKL_Domain_Get_Max_Threads - - void MKL_Set_Dynamic(int); - #define mkl_set_dynamic MKL_Set_Dynamic - - int MKL_Get_Dynamic(void); - #define mkl_get_dynamic MKL_Get_Dynamic - } - """ - return header + """C header for MKL threads interface.""" + return _read_c_code_file("mkl_threads.h") +@functools.cache def openblas_threads_text(): - """C header for OpenBLAS threads interface""" - header = """ - extern "C" - { - void openblas_set_num_threads(int); - void goto_set_num_threads(int); - int openblas_get_num_threads(void); - } - """ - return header + """C header for OpenBLAS threads interface.""" + return _read_c_code_file("openblas_threads.h") def blas_header_version(): - # Version 11: Removed obsolete macOS 10.6 sdot bug workaround - return (11,) - - -def ____gemm_code(check_ab, a_init, b_init): - mod = "%" - return f""" - const char * error_string = NULL; - - int type_num = PyArray_DESCR(_x)->type_num; - int type_size = PyArray_ITEMSIZE(_x); // in bytes - - npy_intp* Nx = PyArray_DIMS(_x); - npy_intp* Ny = PyArray_DIMS(_y); - npy_intp* Nz = PyArray_DIMS(_z); - - npy_intp* Sx = PyArray_STRIDES(_x); - npy_intp* Sy = PyArray_STRIDES(_y); - npy_intp* Sz = PyArray_STRIDES(_z); - - size_t sx_0, sx_1, sy_0, sy_1, sz_0, sz_1; - - int unit = 0; - - if (PyArray_NDIM(_x) != 2) goto _dot_execute_fallback; - if (PyArray_NDIM(_y) != 2) goto _dot_execute_fallback; - if (PyArray_NDIM(_z) != 2) goto _dot_execute_fallback; - - {check_ab} - - if ((PyArray_DESCR(_x)->type_num != NPY_DOUBLE) - && (PyArray_DESCR(_x)->type_num != NPY_FLOAT)) - goto _dot_execute_fallback; - - if ((PyArray_DESCR(_y)->type_num != NPY_DOUBLE) - && (PyArray_DESCR(_y)->type_num != NPY_FLOAT)) - goto _dot_execute_fallback; - - if ((PyArray_DESCR(_y)->type_num != NPY_DOUBLE) - && (PyArray_DESCR(_y)->type_num != NPY_FLOAT)) - goto _dot_execute_fallback; - - if ((PyArray_DESCR(_x)->type_num != PyArray_DESCR(_y)->type_num) - ||(PyArray_DESCR(_x)->type_num != PyArray_DESCR(_z)->type_num)) - goto _dot_execute_fallback; - - - if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1])) - {{ - error_string = "Input dimensions do not agree"; - goto _dot_execute_fail; - }} - if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] {mod} type_size) || (Sx[1] {mod} type_size) - || (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] {mod} type_size) || (Sy[1] {mod} type_size) - || (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] {mod} type_size) || (Sz[1] {mod} type_size)) - {{ - goto _dot_execute_fallback; - }} - - /* - encode the stride structure of _x,_y,_z into a single integer - */ - unit |= ((Sx[1] == type_size) ? 0x0 : (Sx[0] == type_size) ? 0x1 : 0x2) << 0; - unit |= ((Sy[1] == type_size) ? 0x0 : (Sy[0] == type_size) ? 0x1 : 0x2) << 4; - unit |= ((Sz[1] == type_size) ? 0x0 : (Sz[0] == type_size) ? 0x1 : 0x2) << 8; - - /* create appropriate strides for malformed matrices that are row or column - * vectors - */ - sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : Nx[1]; - sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : Nx[0]; - sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : Ny[1]; - sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : Ny[0]; - sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : Nz[1]; - sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0]; - - switch (type_num) - {{ - case NPY_FLOAT: - {{ - #define REAL float - float a = {a_init}; - float b = {b_init}; - - float* x = (float*)PyArray_DATA(_x); - float* y = (float*)PyArray_DATA(_y); - float* z = (float*)PyArray_DATA(_z); - - switch(unit) - {{ - case 0x000: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break; - case 0x001: cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break; - case 0x010: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break; - case 0x011: cblas_sgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break; - case 0x100: cblas_sgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break; - case 0x101: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break; - case 0x110: cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break; - case 0x111: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break; - default: goto _dot_execute_fallback; - }}; - #undef REAL - }} - break; - case NPY_DOUBLE: - {{ - #define REAL double - double a = {a_init}; - double b = {b_init}; - - double* x = (double*)PyArray_DATA(_x); - double* y = (double*)PyArray_DATA(_y); - double* z = (double*)PyArray_DATA(_z); - switch(unit) - {{ - case 0x000: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break; - case 0x001: cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break; - case 0x010: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break; - case 0x011: cblas_dgemm(CblasRowMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_0); break; - case 0x100: cblas_dgemm(CblasColMajor, CblasTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_1); break; - case 0x101: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_1); break; - case 0x110: cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break; - case 0x111: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break; - default: goto _dot_execute_fallback; - }}; - #undef REAL - }} - break; - }} - - return 0; //success! - - _dot_execute_fallback: - PyErr_SetString(PyExc_NotImplementedError, - "dot->execute() fallback"); - return -1; - - _dot_execute_fail: - if (error_string == NULL) - PyErr_SetString(PyExc_ValueError, - "dot->execute() can't run on these inputs"); - return -1; + """Return version tuple for cache invalidation. - /* v 1 */ + This version should be bumped when the static header files change. """ + # Version 12: Refactored to use external .h files + return (12,) From 2607a49649ac0309c8271e429720ac06538b6bdd Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 12 Apr 2026 19:09:39 -0500 Subject: [PATCH 07/13] Restore macOS sdot bug, move test code and patch to separate files --- pytensor/tensor/blas/blas_headers.py | 123 +++++++++++++++++- .../macos_sdot_bugfix/macos_sdot_error.h | 19 +++ .../macos_sdot_bugfix/macos_sdot_fix_test.cpp | 32 +++++ .../macos_sdot_bugfix/macos_sdot_test.cpp | 27 ++++ .../macos_sdot_bugfix/macos_sdot_workaround.h | 14 ++ 5 files changed, 211 insertions(+), 4 deletions(-) create mode 100644 pytensor/tensor/blas/c_code/macos_sdot_bugfix/macos_sdot_error.h create mode 100644 pytensor/tensor/blas/c_code/macos_sdot_bugfix/macos_sdot_fix_test.cpp create mode 100644 pytensor/tensor/blas/c_code/macos_sdot_bugfix/macos_sdot_test.cpp create mode 100644 pytensor/tensor/blas/c_code/macos_sdot_bugfix/macos_sdot_workaround.h diff --git a/pytensor/tensor/blas/blas_headers.py b/pytensor/tensor/blas/blas_headers.py index d014cbae39..050b9d9848 100644 --- a/pytensor/tensor/blas/blas_headers.py +++ b/pytensor/tensor/blas/blas_headers.py @@ -5,14 +5,17 @@ The static C declarations are stored in .h files under c_code/ for better IDE support and maintainability. This module reads those files and assembles -the complete header text. +the complete header text, adding dynamic parts like the macOS sdot bug workaround. """ import functools import logging +import os +import sys from pathlib import Path from pytensor.configdefaults import config +from pytensor.link.c.cmodule import GCC_compiler _logger = logging.getLogger("pytensor.tensor.blas") @@ -21,6 +24,101 @@ _C_CODE_DIR = Path(__file__).parent / "c_code" +def detect_macos_sdot_bug(): + """ + Try to detect a bug in the BLAS sdot_ routine on macOS. + + Apple's Accelerate framework has a long-standing bug where the Fortran + interface sdot_() returns incorrect values. The C interface cblas_sdot() + works correctly. This bug has been present since at least macOS 10.6 + and is STILL PRESENT as of macOS 26 (2026). + + This function compiles and runs a test program to detect the bug, + then tests if a workaround (using cblas_sdot instead) works. + + Three attributes of this function will be set: + - detect_macos_sdot_bug.tested: True after first call + - detect_macos_sdot_bug.present: True if bug is detected + - detect_macos_sdot_bug.fix_works: True if cblas_sdot workaround works + """ + _logger.debug("Starting detection of bug in Mac OS BLAS sdot_ routine") + if detect_macos_sdot_bug.tested: + return detect_macos_sdot_bug.present + + if sys.platform != "darwin" or not config.blas__ldflags: + _logger.info("Not Mac OS, no sdot_ bug") + detect_macos_sdot_bug.tested = True + return False + + # This code will return -1 if the dot product did not return + # the right value (30.). + flags = config.blas__ldflags.split() + for f in flags: + # Library directories should also be added as rpath, + # so that they can be loaded even if the environment + # variable LD_LIBRARY_PATH does not contain them + lib_path = os.environ.get("DYLD_FALLBACK_LIBRARY_PATH", "").split(":") + if f.startswith("-L"): + flags.append("-Wl,-rpath," + f[2:]) + # also append those paths to DYLD_FALLBACK_LIBRARY_PATH to + # support libraries that have the wrong install_name + # (such as MKL on canopy installs) + if f[2:] not in lib_path: + lib_path.append(f[2:]) + # this goes into the python process environment that is + # inherited by subprocesses/used by dyld when loading new objects + os.environ["DYLD_FALLBACK_LIBRARY_PATH"] = ":".join(lib_path) + + test_code = _read_c_code_file("macos_sdot_bugfix/macos_sdot_test.cpp") + + _logger.debug("Trying to compile and run test case.") + compilation_ok, run_ok = GCC_compiler.try_compile_tmp( + test_code, tmp_prefix="detect_macos_sdot_bug_", flags=flags, try_run=True + ) + detect_macos_sdot_bug.tested = True + + # If compilation failed, we consider there is a bug, + # and the fix does not work + if not compilation_ok: + _logger.info("Could not compile test case for sdot_.") + detect_macos_sdot_bug.present = True + return True + + if run_ok: + _logger.info("The sdot_ bug is not present on this system.") + detect_macos_sdot_bug.present = False + return False + + # Else, the bug is detected. + _logger.info("The sdot_ bug is present on this system.") + detect_macos_sdot_bug.present = True + + # Then, try a simple fix + test_fix_code = _read_c_code_file("macos_sdot_bugfix/macos_sdot_fix_test.cpp") + + _logger.debug("Trying to compile and run tentative workaround.") + compilation_fix_ok, run_fix_ok = GCC_compiler.try_compile_tmp( + test_fix_code, + tmp_prefix="detect_macos_sdot_bug_testfix_", + flags=flags, + try_run=True, + ) + + _logger.info( + "Status of tentative fix -- compilation OK: %s, works: %s", + compilation_fix_ok, + run_fix_ok, + ) + detect_macos_sdot_bug.fix_works = run_fix_ok + + return detect_macos_sdot_bug.present + + +detect_macos_sdot_bug.tested = False +detect_macos_sdot_bug.present = False +detect_macos_sdot_bug.fix_works = False + + @functools.cache def _read_c_code_file(filename: str) -> str: """Read a C code file from the c_code directory.""" @@ -38,6 +136,7 @@ def blas_header_text(): Returns the complete BLAS header text including: - Fortran BLAS declarations (from fortran_blas.h) + - macOS sdot bug workaround (if applicable) - NumPy-based fallback BLAS (if no system BLAS available) """ blas_code = "" @@ -73,6 +172,14 @@ def blas_header_text(): # Read the Fortran BLAS declarations from the static header file header = _read_c_code_file("fortran_blas.h") + # Add macOS sdot bug workaround if needed + if detect_macos_sdot_bug(): + if detect_macos_sdot_bug.fix_works: + header += _read_c_code_file("macos_sdot_bugfix/macos_sdot_workaround.h") + else: + # Make sure the buggy version of sdot_ is never used + header += _read_c_code_file("macos_sdot_bugfix/macos_sdot_error.h") + return header + blas_code @@ -91,7 +198,15 @@ def openblas_threads_text(): def blas_header_version(): """Return version tuple for cache invalidation. - This version should be bumped when the static header files change. + This version should be bumped when: + - The static header files change + - The sdot bug workaround logic changes """ - # Version 12: Refactored to use external .h files - return (12,) + # Version 13: Restored macOS sdot bug workaround + version = (13,) + if detect_macos_sdot_bug(): + if detect_macos_sdot_bug.fix_works: + version += (1,) + else: + version += (2,) + return version diff --git a/pytensor/tensor/blas/c_code/macos_sdot_bugfix/macos_sdot_error.h b/pytensor/tensor/blas/c_code/macos_sdot_bugfix/macos_sdot_error.h new file mode 100644 index 0000000000..1164f5417c --- /dev/null +++ b/pytensor/tensor/blas/c_code/macos_sdot_bugfix/macos_sdot_error.h @@ -0,0 +1,19 @@ +/* + * macOS sdot_ bug fatal error stub. + * + * When the sdot_ bug is detected but no workaround is available, + * this stub ensures we fail loudly rather than silently returning + * incorrect results. + */ + +static float sdot_(int* Nx, float* x, int* Sx, float* y, int* Sy) +{ + fprintf(stderr, + "FATAL: The implementation of BLAS SDOT " + "routine in your system has a bug that " + "makes it return wrong results.\n" + "You can work around this bug by using a " + "different BLAS library, or disabling BLAS\n"); + assert(0); +} + diff --git a/pytensor/tensor/blas/c_code/macos_sdot_bugfix/macos_sdot_fix_test.cpp b/pytensor/tensor/blas/c_code/macos_sdot_bugfix/macos_sdot_fix_test.cpp new file mode 100644 index 0000000000..9890e7df9f --- /dev/null +++ b/pytensor/tensor/blas/c_code/macos_sdot_bugfix/macos_sdot_fix_test.cpp @@ -0,0 +1,32 @@ +/* + * Test program to verify the macOS BLAS sdot_ bug workaround. + * + * This defines a static sdot_ wrapper that uses cblas_sdot internally, + * then tests if it returns the correct result. The C interface cblas_sdot() + * works correctly even when the Fortran sdot_() is buggy. + * + * Expected result: 0*0 + 1*1 + 2*2 + 3*3 + 4*4 = 30 + * Returns 0 if workaround works, -1 if it fails. + */ + +extern "C" float cblas_sdot(int, float*, int, float*, int); + +static float sdot_(int* Nx, float* x, int* Sx, float* y, int* Sy) +{ + return cblas_sdot(*Nx, x, *Sx, y, *Sy); +} + +int main(int argc, char** argv) +{ + int Nx = 5; + int Sx = 1; + float x[5] = {0, 1, 2, 3, 4}; + float r = sdot_(&Nx, x, &Sx, x, &Sx); + + if ((r - 30.f) > 1e-6 || (r - 30.f) < -1e-6) + { + return -1; + } + return 0; +} + diff --git a/pytensor/tensor/blas/c_code/macos_sdot_bugfix/macos_sdot_test.cpp b/pytensor/tensor/blas/c_code/macos_sdot_bugfix/macos_sdot_test.cpp new file mode 100644 index 0000000000..771a70decd --- /dev/null +++ b/pytensor/tensor/blas/c_code/macos_sdot_bugfix/macos_sdot_test.cpp @@ -0,0 +1,27 @@ +/* + * Test program to detect the macOS BLAS sdot_ bug. + * + * Apple's Accelerate framework has a long-standing bug where the Fortran + * interface sdot_() returns incorrect values. This test computes a simple + * dot product and checks if the result is correct. + * + * Expected result: 0*0 + 1*1 + 2*2 + 3*3 + 4*4 = 30 + * Returns 0 if correct, -1 if bug is present. + */ + +extern "C" float sdot_(int*, float*, int*, float*, int*); + +int main(int argc, char** argv) +{ + int Nx = 5; + int Sx = 1; + float x[5] = {0, 1, 2, 3, 4}; + float r = sdot_(&Nx, x, &Sx, x, &Sx); + + if ((r - 30.f) > 1e-6 || (r - 30.f) < -1e-6) + { + return -1; + } + return 0; +} + diff --git a/pytensor/tensor/blas/c_code/macos_sdot_bugfix/macos_sdot_workaround.h b/pytensor/tensor/blas/c_code/macos_sdot_bugfix/macos_sdot_workaround.h new file mode 100644 index 0000000000..09593a306c --- /dev/null +++ b/pytensor/tensor/blas/c_code/macos_sdot_bugfix/macos_sdot_workaround.h @@ -0,0 +1,14 @@ +/* + * macOS sdot_ bug workaround. + * + * Apple's Accelerate framework has a bug where the Fortran sdot_() interface + * returns incorrect values. This wrapper uses cblas_sdot() instead, which + * works correctly. + */ + +extern "C" float cblas_sdot(int, float*, int, float*, int); +static float sdot_(int* Nx, float* x, int* Sx, float* y, int* Sy) +{ + return cblas_sdot(*Nx, x, *Sx, y, *Sy); +} + From f009725ec3ce45d7c55f7257ed5f58c99ba3769e Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 12 Apr 2026 19:09:39 -0500 Subject: [PATCH 08/13] Move GEMM c code to header files where possible --- pytensor/tensor/blas/c_code/gemm_helper.h | 180 ++++++++++++++++++++++ pytensor/tensor/blas/gemm.py | 101 ++++-------- 2 files changed, 212 insertions(+), 69 deletions(-) create mode 100644 pytensor/tensor/blas/c_code/gemm_helper.h diff --git a/pytensor/tensor/blas/c_code/gemm_helper.h b/pytensor/tensor/blas/c_code/gemm_helper.h new file mode 100644 index 0000000000..304b38c612 --- /dev/null +++ b/pytensor/tensor/blas/c_code/gemm_helper.h @@ -0,0 +1,180 @@ +/* + * GEMM helper functions for PyTensor. + * + * This file contains the core GEMM dispatch logic extracted from the + * Python code generation templates. The goal is to have real C code + * that IDEs can parse, with minimal dynamic parts. + */ + +#ifndef PYTENSOR_GEMM_HELPER_H +#define PYTENSOR_GEMM_HELPER_H + +#include +#include + +/* Include BLAS declarations */ +#include "fortran_blas.h" + +#ifndef MOD +#define MOD % +#endif + +/* + * Compute strides for a contiguous array. + * Used when PyArray_STRIDES returns invalid values (e.g., for size-0 arrays). + */ +static inline void compute_strides(npy_intp *shape, int ndim, int type_size, npy_intp *res) { + res[ndim - 1] = type_size; + for (int i = ndim - 1; i > 0; i--) { + npy_intp s = shape[i]; + res[i - 1] = res[i] * (s > 0 ? s : 1); + } +} + +/* + * Encode the stride structure of three 2D arrays into a single integer. + * + * For each array, we encode: + * 0x0 = row-major (last stride == type_size) or single column + * 0x1 = column-major (first stride == type_size) or single row + * 0x2 = neither (will trigger error) + * + * The encoding is: (x_code << 8) | (y_code << 4) | (z_code << 0) + */ +static inline int pytensor_encode_gemm_strides( + npy_intp *Nx, npy_intp *Sx, + npy_intp *Ny, npy_intp *Sy, + npy_intp *Nz, npy_intp *Sz, + int type_size +) { + int unit = 0; + unit |= ((Sx[1] == type_size || Nx[1] == 1) ? 0x0 : (Sx[0] == type_size || Nx[0] == 1) ? 0x1 : 0x2) << 8; + unit |= ((Sy[1] == type_size || Ny[1] == 1) ? 0x0 : (Sy[0] == type_size || Ny[0] == 1) ? 0x1 : 0x2) << 4; + unit |= ((Sz[1] == type_size || Nz[1] == 1) ? 0x0 : (Sz[0] == type_size || Nz[0] == 1) ? 0x1 : 0x2) << 0; + return unit; +} + +/* + * Compute BLAS-compatible strides from NumPy strides. + * + * BLAS requires leading dimensions to be >= 1 and not smaller than + * the number of elements in that dimension. For vectors or empty + * matrices, we need to compute valid dummy strides. + */ +static inline void pytensor_compute_gemm_strides( + npy_intp *Nx, npy_intp *Sx, int *sx_0, int *sx_1, + npy_intp *Ny, npy_intp *Sy, int *sy_0, int *sy_1, + npy_intp *Nz, npy_intp *Sz, int *sz_0, int *sz_1, + int type_size +) { + *sx_0 = (Nx[0] > 1) ? Sx[0] / type_size : (Nx[1] + 1); + *sx_1 = (Nx[1] > 1) ? Sx[1] / type_size : (Nx[0] + 1); + *sy_0 = (Ny[0] > 1) ? Sy[0] / type_size : (Ny[1] + 1); + *sy_1 = (Ny[1] > 1) ? Sy[1] / type_size : (Ny[0] + 1); + *sz_0 = (Nz[0] > 1) ? Sz[0] / type_size : (Nz[1] + 1); + *sz_1 = (Nz[1] > 1) ? Sz[1] / type_size : (Nz[0] + 1); +} + +/* + * Call sgemm_ with the appropriate transpose flags based on stride encoding. + * + * Returns 0 on success, -1 on error (with Python exception set). + */ +static inline int pytensor_sgemm_dispatch( + int unit, + float *x, float *y, float *z, + float a, float b, + int Nz0, int Nz1, int Nx1, + int sx_0, int sx_1, int sy_0, int sy_1, int sz_0, int sz_1 +) { + char N = 'N'; + char T = 'T'; + + switch (unit) { + case 0x000: sgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break; + case 0x100: sgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break; + case 0x010: sgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break; + case 0x110: sgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break; + case 0x001: sgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break; + case 0x101: sgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break; + case 0x011: sgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break; + case 0x111: sgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break; + default: + PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); + return -1; + } + return 0; +} + +/* + * Call dgemm_ with the appropriate transpose flags based on stride encoding. + * + * Returns 0 on success, -1 on error (with Python exception set). + */ +static inline int pytensor_dgemm_dispatch( + int unit, + double *x, double *y, double *z, + double a, double b, + int Nz0, int Nz1, int Nx1, + int sx_0, int sx_1, int sy_0, int sy_1, int sz_0, int sz_1 +) { + char N = 'N'; + char T = 'T'; + + switch (unit) { + case 0x000: dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break; + case 0x100: dgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break; + case 0x010: dgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break; + case 0x110: dgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break; + case 0x001: dgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break; + case 0x101: dgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break; + case 0x011: dgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break; + case 0x111: dgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break; + default: + PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); + return -1; + } + return 0; +} + +/* + * Check if an array needs to be copied to make it BLAS-compatible. + * + * BLAS requires arrays to have at least one unit stride and valid + * (non-negative, properly aligned) strides. + */ +static inline int pytensor_needs_copy_for_blas(npy_intp *N, npy_intp *S, int type_size) { + return (S[0] < 1) || (S[1] < 1) + || (S[0] MOD type_size) || (S[1] MOD type_size) + || ((S[0] != type_size) && (S[1] != type_size)); +} + +/* + * Ensure an array is BLAS-compatible, copying if necessary. + * + * If the array needs copying, *arr is updated to point to the copy + * and *S is updated to the new strides. The caller must handle + * reference counting appropriately. + * + * Returns 0 on success, -1 on error (with Python exception set). + */ +static inline int pytensor_ensure_blas_compatible( + PyArrayObject **arr, npy_intp *N, npy_intp **S, int type_size +) { + if (pytensor_needs_copy_for_blas(N, *S, type_size)) { + PyArrayObject *copy = (PyArrayObject *)PyArray_Copy(*arr); + if (!copy) { + return -1; + } + Py_DECREF(*arr); + *arr = copy; + *S = PyArray_STRIDES(*arr); + if ((*S)[0] < 1 || (*S)[1] < 1) { + compute_strides(N, 2, type_size, *S); + } + } + return 0; +} + +#endif /* PYTENSOR_GEMM_HELPER_H */ + diff --git a/pytensor/tensor/blas/gemm.py b/pytensor/tensor/blas/gemm.py index 02cdb9785a..96e778c2d0 100644 --- a/pytensor/tensor/blas/gemm.py +++ b/pytensor/tensor/blas/gemm.py @@ -1,3 +1,5 @@ +from pathlib import Path + import numpy as np import pytensor.scalar @@ -13,12 +15,18 @@ view_roots, ) from pytensor.tensor.blas.blas_headers import ( + _read_c_code_file, blas_header_text, blas_header_version, ) from pytensor.tensor.type import DenseTensorType, tensor +def _read_gemm_helper_h(): + """Read the GEMM helper header file.""" + return _read_c_code_file("gemm_helper.h") + + class GemmRelated(COp): """Base class for Gemm and Dot22. @@ -28,20 +36,8 @@ class GemmRelated(COp): __props__: tuple[str, ...] = () def c_support_code(self, **kwargs): - mod_str = """ - #ifndef MOD - #define MOD % - #endif - void compute_strides(npy_intp *shape, int N_shape, int type_size, npy_intp *res) { - int s; - res[N_shape - 1] = type_size; - for (int i = N_shape - 1; i > 0; i--) { - s = shape[i]; - res[i - 1] = res[i] * (s > 0 ? s : 1); - } - } - """ - return blas_header_text() + mod_str + # Include BLAS headers and GEMM helper functions + return blas_header_text() + _read_gemm_helper_h() def c_headers(self, **kwargs): return [] @@ -59,7 +55,9 @@ def c_lib_dirs(self, **kwargs): return ldflags(libs=False, libs_dir=True) def c_header_dirs(self, **kwargs): - return ldflags(libs=False, include_dir=True) + # Include the c_code directory for our header files + c_code_dir = str(Path(__file__).parent / "c_code") + return [c_code_dir, *ldflags(libs=False, include_dir=True)] declare_NS = """ int unit = 0; @@ -166,8 +164,7 @@ def c_header_dirs(self, **kwargs): If some matrices are not contiguous on either dimensions, or have invalid strides, copy their content into a contiguous one */ - if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] MOD type_size) || (Sx[1] MOD type_size) - || ((Sx[0] != type_size) && (Sx[1] != type_size))) + if (pytensor_needs_copy_for_blas(Nx, Sx, type_size)) { PyArrayObject * _x_copy = (PyArrayObject *) PyArray_Copy(%(_x)s); if (!_x_copy) @@ -180,8 +177,7 @@ def c_header_dirs(self, **kwargs): } } - if ((Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] MOD type_size) || (Sy[1] MOD type_size) - || ((Sy[0] != type_size) && (Sy[1] != type_size))) + if (pytensor_needs_copy_for_blas(Ny, Sy, type_size)) { PyArrayObject * _y_copy = (PyArrayObject *) PyArray_Copy(%(_y)s); if (!_y_copy) @@ -194,8 +190,7 @@ def c_header_dirs(self, **kwargs): } } - if ((Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] MOD type_size) || (Sz[1] MOD type_size) - || ((Sz[0] != type_size) && (Sz[1] != type_size))) + if (pytensor_needs_copy_for_blas(Nz, Sz, type_size)) { PyArrayObject * _z_copy = (PyArrayObject *) PyArray_Copy(%(_zout)s); if (!_z_copy) @@ -213,9 +208,7 @@ def c_header_dirs(self, **kwargs): /* encode the stride structure of _x,_y,_zout into a single integer */ - unit |= ((Sx[1] == type_size || Nx[1]==1) ? 0x0 : (Sx[0] == type_size || Nx[0]==1) ? 0x1 : 0x2) << 8; - unit |= ((Sy[1] == type_size || Ny[1]==1) ? 0x0 : (Sy[0] == type_size || Ny[0]==1) ? 0x1 : 0x2) << 4; - unit |= ((Sz[1] == type_size || Nz[1]==1) ? 0x0 : (Sz[0] == type_size || Nz[0]==1) ? 0x1 : 0x2) << 0; + unit = pytensor_encode_gemm_strides(Nx, Sx, Ny, Sy, Nz, Sz, type_size); """ compute_strides = """ @@ -226,12 +219,10 @@ def c_header_dirs(self, **kwargs): * - they are not smaller than the number of elements in the array, * - they are not 0. */ - sx_0 = (Nx[0] > 1) ? Sx[0]/type_size : (Nx[1] + 1); - sx_1 = (Nx[1] > 1) ? Sx[1]/type_size : (Nx[0] + 1); - sy_0 = (Ny[0] > 1) ? Sy[0]/type_size : (Ny[1] + 1); - sy_1 = (Ny[1] > 1) ? Sy[1]/type_size : (Ny[0] + 1); - sz_0 = (Nz[0] > 1) ? Sz[0]/type_size : (Nz[1] + 1); - sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : (Nz[0] + 1); + pytensor_compute_gemm_strides(Nx, Sx, &sx_0, &sx_1, + Ny, Sy, &sy_0, &sy_1, + Nz, Sz, &sz_0, &sz_1, + type_size); """ begin_switch_typenum = """ @@ -250,21 +241,12 @@ def c_header_dirs(self, **kwargs): float* x = (float*)PyArray_DATA(%(_x)s); float* y = (float*)PyArray_DATA(%(_y)s); float* z = (float*)PyArray_DATA(%(_zout)s); - char N = 'N'; - char T = 'T'; int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1]; - switch(unit) - { - case 0x000: sgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_0, &b, z, &sz_0); break; - case 0x100: sgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_0, x, &sx_1, &b, z, &sz_0); break; - case 0x010: sgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_0, &b, z, &sz_0); break; - case 0x110: sgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, &sy_1, x, &sx_1, &b, z, &sz_0); break; - case 0x001: sgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_0, &b, z, &sz_1); break; - case 0x101: sgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_0, &b, z, &sz_1); break; - case 0x011: sgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_0, y, &sy_1, &b, z, &sz_1); break; - case 0x111: sgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, &sx_1, y, &sy_1, &b, z, &sz_1); break; - default: PyErr_SetString(PyExc_ValueError, "some matrix has no unit stride"); %(fail)s; - }; + if (pytensor_sgemm_dispatch(unit, x, y, z, a, b, + Nz0, Nz1, Nx1, + sx_0, sx_1, sy_0, sy_1, sz_0, sz_1) != 0) { + %(fail)s; + } """ case_double = """ @@ -280,31 +262,12 @@ def c_header_dirs(self, **kwargs): double* x = (double*)PyArray_DATA(%(_x)s); double* y = (double*)PyArray_DATA(%(_y)s); double* z = (double*)PyArray_DATA(%(_zout)s); - char N = 'N'; - char T = 'T'; int Nz0 = Nz[0], Nz1 = Nz[1], Nx1 = Nx[1]; - switch(unit) - { - case 0x000: dgemm_(&N, &N, &Nz1, &Nz0, &Nx1, &a, y, - &sy_0, x, &sx_0, &b, z, &sz_0); break; - case 0x100: dgemm_(&N, &T, &Nz1, &Nz0, &Nx1, &a, y, - &sy_0, x, &sx_1, &b, z, &sz_0); break; - case 0x010: dgemm_(&T, &N, &Nz1, &Nz0, &Nx1, &a, y, - &sy_1, x, &sx_0, &b, z, &sz_0); break; - case 0x110: dgemm_(&T, &T, &Nz1, &Nz0, &Nx1, &a, y, - &sy_1, x, &sx_1, &b, z, &sz_0); break; - case 0x001: dgemm_(&T, &T, &Nz0, &Nz1, &Nx1, &a, x, - &sx_0, y, &sy_0, &b, z, &sz_1); break; - case 0x101: dgemm_(&N, &T, &Nz0, &Nz1, &Nx1, &a, x, - &sx_1, y, &sy_0, &b, z, &sz_1); break; - case 0x011: dgemm_(&T, &N, &Nz0, &Nz1, &Nx1, &a, x, - &sx_0, y, &sy_1, &b, z, &sz_1); break; - case 0x111: dgemm_(&N, &N, &Nz0, &Nz1, &Nx1, &a, x, - &sx_1, y, &sy_1, &b, z, &sz_1); break; - default: PyErr_SetString(PyExc_ValueError, - "some matrix has no unit stride"); - %(fail)s; - }; + if (pytensor_dgemm_dispatch(unit, x, y, z, a, b, + Nz0, Nz1, Nx1, + sx_0, sx_1, sy_0, sy_1, sz_0, sz_1) != 0) { + %(fail)s; + } """ end_switch_typenum = """ @@ -343,7 +306,7 @@ def build_gemm_call(self): ) def build_gemm_version(self): - return (14, blas_header_version()) + return (15, blas_header_version()) class Gemm(GemmRelated): From 9cbc8ba04177271f846a2fdb6ba2b2ce0165e007 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 12 Apr 2026 19:20:31 -0500 Subject: [PATCH 09/13] Move GEMV c code to header files where possible --- pytensor/tensor/blas/blas_c.py | 87 +++++----- pytensor/tensor/blas/c_code/gemv_helper.h | 186 ++++++++++++++++++++++ 2 files changed, 222 insertions(+), 51 deletions(-) create mode 100644 pytensor/tensor/blas/c_code/gemv_helper.h diff --git a/pytensor/tensor/blas/blas_c.py b/pytensor/tensor/blas/blas_c.py index 848a730cd6..93270c3442 100644 --- a/pytensor/tensor/blas/blas_c.py +++ b/pytensor/tensor/blas/blas_c.py @@ -1,12 +1,23 @@ +from pathlib import Path + from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType from pytensor.scalar import bool as bool_t from pytensor.tensor.blas._core import ldflags -from pytensor.tensor.blas.blas_headers import blas_header_text, blas_header_version +from pytensor.tensor.blas.blas_headers import ( + _read_c_code_file, + blas_header_text, + blas_header_version, +) from pytensor.tensor.blas.gemv import Gemv from pytensor.tensor.blas.ger import Ger +def _read_gemv_helper_h(): + """Read the GEMV helper header file.""" + return _read_c_code_file("gemv_helper.h") + + class BaseBLAS(COp): def c_libraries(self, **kwargs): return ldflags() @@ -18,10 +29,12 @@ def c_lib_dirs(self, **kwargs): return ldflags(libs=False, libs_dir=True) def c_header_dirs(self, **kwargs): - return ldflags(libs=False, include_dir=True) + # Include the c_code directory for our header files + c_code_dir = str(Path(__file__).parent / "c_code") + return [c_code_dir, *ldflags(libs=False, include_dir=True)] def c_support_code(self, **kwargs): - return blas_header_text() + return blas_header_text() + _read_gemv_helper_h() # ##### ####### ####### @@ -468,7 +481,7 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, must_initialize_y=False, params=N SA1 = -SA1; // Iterate over columns in reverse Sx = -Sx; // Iterate over x in reverse } - } else if ((SA0 < 0) || (SA1 < 0) || ((SA0 != 1) && (SA1 != 1))) + } else if (pytensor_gemv_needs_copy(SA0, SA1)) { // Array isn't contiguous, we have to make a copy // - if the copy is too long, maybe call vector/vector dot on each row instead @@ -494,65 +507,37 @@ def gemv_c_code(y, A, x, z, alpha, beta, fail, must_initialize_y=False, params=N if (is_float) { - z_data[0] = dbeta != 0 ? dbeta * z_data[0] : 0.f; - z_data[0] += alpha * sdot_(&NA1, (float*)(A_data), &SA1, - (float*)x_data, &Sx); + pytensor_sgemv_dot_case(NA1, SA1, + (float*)A_data, (float*)x_data, (float*)z_data, + alpha, fbeta, Sx); } else { - z_data[0] = dbeta != 0 ? dbeta * z_data[0] : 0.; - z_data[0] += alpha * ddot_(&NA1, (double*)(A_data), &SA1, - (double*)x_data, &Sx); + pytensor_dgemv_dot_case(NA1, SA1, + (double*)A_data, (double*)x_data, (double*)z_data, + alpha, dbeta, Sx); } } - else if (SA0 == 1) + else if (SA0 == 1 || SA1 == 1) { - // F-contiguous - char NOTRANS = 'N'; + // C-contiguous or F-contiguous, use GEMV dispatch helper if (is_float) { float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; - sgemv_(&NOTRANS, &NA0, &NA1, - &alpha, - (float*)(A_data), &SA1, - (float*)x_data, &Sx, - &fbeta, - (float*)z_data, &Sz); + if (pytensor_sgemv_dispatch(NA0, NA1, SA0, SA1, + (float*)A_data, (float*)x_data, (float*)z_data, + alpha, fbeta, Sx, Sz) != 0) { + %(fail)s + } } else { double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; - dgemv_(&NOTRANS, &NA0, &NA1, - &alpha, - (double*)(A_data), &SA1, - (double*)x_data, &Sx, - &dbeta, - (double*)z_data, &Sz); - } - } - else if (SA1 == 1) - { - // C-contiguous - char TRANS = 'T'; - if (is_float) - { - float alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; - sgemv_(&TRANS, &NA1, &NA0, - &alpha, - (float*)(A_data), &SA0, - (float*)x_data, &Sx, - &fbeta, - (float*)z_data, &Sz); - } - else - { - double alpha = ((dtype_%(alpha)s*)PyArray_DATA(%(alpha)s))[0]; - dgemv_(&TRANS, &NA1, &NA0, - &alpha, - (double*)(A_data), &SA0, - (double*)x_data, &Sx, - &dbeta, - (double*)z_data, &Sz); + if (pytensor_dgemv_dispatch(NA0, NA1, SA0, SA1, + (double*)A_data, (double*)x_data, (double*)z_data, + alpha, dbeta, Sx, Sz) != 0) { + %(fail)s + } } } else @@ -604,7 +589,7 @@ def c_code(self, node, name, inp, out, sub): return code def c_code_cache_version(self): - return (18, blas_header_version(), must_initialize_y_gemv()) + return (19, blas_header_version(), must_initialize_y_gemv()) cgemv_inplace = CGemv(inplace=True) diff --git a/pytensor/tensor/blas/c_code/gemv_helper.h b/pytensor/tensor/blas/c_code/gemv_helper.h new file mode 100644 index 0000000000..9cb8e3ae04 --- /dev/null +++ b/pytensor/tensor/blas/c_code/gemv_helper.h @@ -0,0 +1,186 @@ +/* + * GEMV helper functions for PyTensor. + * + * This file contains GEMV dispatch logic extracted from Python code generation + * templates. The goal is to have real C code that IDEs can parse. + * + * GEMV computes: z <- beta * y + alpha * dot(A, x) + * where A is a matrix, x and y are vectors. + */ + +#ifndef PYTENSOR_GEMV_HELPER_H +#define PYTENSOR_GEMV_HELPER_H + +#include +#include + +/* Include BLAS declarations */ +#include "fortran_blas.h" + +/* + * Compute BLAS-compatible strides for a matrix. + * + * For row or column matrices, the stride in the dummy dimension doesn't matter, + * but BLAS requires it to be no smaller than the number of elements. + */ +static inline void pytensor_gemv_compute_matrix_strides( + int NA0, int NA1, + npy_intp stride0, npy_intp stride1, + int elemsize, + int *SA0, int *SA1 +) { + *SA0 = (NA0 > 1) ? (stride0 / elemsize) : NA1; + *SA1 = (NA1 > 1) ? (stride1 / elemsize) : NA0; +} + +/* + * Check if a matrix needs to be copied to be BLAS-compatible. + * + * Returns 1 if copy needed, 0 if matrix can be used directly. + * A matrix can be used directly if: + * - It's C-contiguous (SA1 == 1) or F-contiguous (SA0 == 1) + * - OR strides are negative but can be handled by reversing iteration + */ +static inline int pytensor_gemv_needs_copy(int SA0, int SA1) { + /* Can handle negative strides by reversing iteration if one stride is ±1 */ + if ((SA0 < 0 || SA1 < 0) && (abs(SA0) == 1 || abs(SA1) == 1)) { + return 0; + } + /* Otherwise need copy if neither stride is 1 or if strides are negative */ + return (SA0 < 0) || (SA1 < 0) || ((SA0 != 1) && (SA1 != 1)); +} + +/* + * Call sgemv_ for float matrix-vector multiply. + * + * Handles both C-contiguous and F-contiguous layouts. + * For C-contiguous (SA1 == 1): uses transpose + * For F-contiguous (SA0 == 1): no transpose + * + * Returns 0 on success, -1 on error. + */ +static inline int pytensor_sgemv_dispatch( + int NA0, int NA1, + int SA0, int SA1, + float *A_data, float *x_data, float *z_data, + float alpha, float beta, + int Sx, int Sz +) { + if (SA0 == 1) { + /* F-contiguous */ + char NOTRANS = 'N'; + sgemv_(&NOTRANS, &NA0, &NA1, + &alpha, A_data, &SA1, + x_data, &Sx, + &beta, z_data, &Sz); + } else if (SA1 == 1) { + /* C-contiguous */ + char TRANS = 'T'; + sgemv_(&TRANS, &NA1, &NA0, + &alpha, A_data, &SA0, + x_data, &Sx, + &beta, z_data, &Sz); + } else { + PyErr_SetString(PyExc_AssertionError, + "A is neither C nor F-contiguous, it should have been copied"); + return -1; + } + return 0; +} + +/* + * Call dgemv_ for double matrix-vector multiply. + * + * Handles both C-contiguous and F-contiguous layouts. + * For C-contiguous (SA1 == 1): uses transpose + * For F-contiguous (SA0 == 1): no transpose + * + * Returns 0 on success, -1 on error. + */ +static inline int pytensor_dgemv_dispatch( + int NA0, int NA1, + int SA0, int SA1, + double *A_data, double *x_data, double *z_data, + double alpha, double beta, + int Sx, int Sz +) { + if (SA0 == 1) { + /* F-contiguous */ + char NOTRANS = 'N'; + dgemv_(&NOTRANS, &NA0, &NA1, + &alpha, A_data, &SA1, + x_data, &Sx, + &beta, z_data, &Sz); + } else if (SA1 == 1) { + /* C-contiguous */ + char TRANS = 'T'; + dgemv_(&TRANS, &NA1, &NA0, + &alpha, A_data, &SA0, + x_data, &Sx, + &beta, z_data, &Sz); + } else { + PyErr_SetString(PyExc_AssertionError, + "A is neither C nor F-contiguous, it should have been copied"); + return -1; + } + return 0; +} + +/* + * Handle vector-vector dot product case (when A has only 1 row). + * + * Computes: z[0] = beta * z[0] + alpha * dot(A[0,:], x) + * + * This is faster than calling gemv for a single row. + */ +static inline void pytensor_sgemv_dot_case( + int NA1, int SA1, + float *A_data, float *x_data, float *z_data, + float alpha, float beta, int Sx +) { + z_data[0] = (beta != 0.0f) ? beta * z_data[0] : 0.0f; + z_data[0] += alpha * sdot_(&NA1, A_data, &SA1, x_data, &Sx); +} + +static inline void pytensor_dgemv_dot_case( + int NA1, int SA1, + double *A_data, double *x_data, double *z_data, + double alpha, double beta, int Sx +) { + z_data[0] = (beta != 0.0) ? beta * z_data[0] : 0.0; + z_data[0] += alpha * ddot_(&NA1, A_data, &SA1, x_data, &Sx); +} + +/* + * Adjust data pointers and strides for negative stride iteration. + * + * When strides are negative but abs(stride) == 1 for one dimension, + * we can handle this by: + * 1. Jumping to the "first" element (which is at the end of the array) + * 2. Negating the stride so we iterate backwards + * 3. Negating corresponding vector strides + * + * This avoids making a copy of the array. + */ +static inline void pytensor_gemv_handle_negative_strides( + int NA0, int NA1, + int *SA0, int *SA1, + int *Sx, int *Sz, + void **A_data, int elemsize +) { + char *A = (char*)*A_data; + if (*SA0 < 0) { + A += (NA0 - 1) * (*SA0) * elemsize; + *SA0 = -(*SA0); + *Sz = -(*Sz); + } + if (*SA1 < 0) { + A += (NA1 - 1) * (*SA1) * elemsize; + *SA1 = -(*SA1); + *Sx = -(*Sx); + } + *A_data = A; +} + +#endif /* PYTENSOR_GEMV_HELPER_H */ + From 693fb57126ba1f1c78da25d4833a8ae4e04f91e0 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 12 Apr 2026 20:00:50 -0500 Subject: [PATCH 10/13] Move GER c code to header files where possible --- pytensor/tensor/blas/blas_c.py | 151 +++++------------- pytensor/tensor/blas/c_code/ger_helper.h | 187 +++++++++++++++++++++++ 2 files changed, 226 insertions(+), 112 deletions(-) create mode 100644 pytensor/tensor/blas/c_code/ger_helper.h diff --git a/pytensor/tensor/blas/blas_c.py b/pytensor/tensor/blas/blas_c.py index 93270c3442..0fdbeff4a0 100644 --- a/pytensor/tensor/blas/blas_c.py +++ b/pytensor/tensor/blas/blas_c.py @@ -18,6 +18,11 @@ def _read_gemv_helper_h(): return _read_c_code_file("gemv_helper.h") +def _read_ger_helper_h(): + """Read the GER helper header file.""" + return _read_c_code_file("ger_helper.h") + + class BaseBLAS(COp): def c_libraries(self, **kwargs): return ldflags() @@ -34,7 +39,7 @@ def c_header_dirs(self, **kwargs): return [c_code_dir, *ldflags(libs=False, include_dir=True)] def c_support_code(self, **kwargs): - return blas_header_text() + _read_gemv_helper_h() + return blas_header_text() + _read_gemv_helper_h() + _read_ger_helper_h() # ##### ####### ####### @@ -117,56 +122,37 @@ def ger_c_code(A, a, x, y, Z, fail, params): }} if (PyArray_DESCR({Z})->type_num == NPY_FLOAT) {{ - float * zoutdata = (float*)PyArray_DATA({Z}); const float * zdata = (float*)PyArray_DATA({A}); + float * zoutdata = (float*)PyArray_DATA({Z}); const float * xdata = (float*)PyArray_DATA({x}); const float * ydata = (float*)PyArray_DATA({y}); - const float * adata = (float*)PyArray_DATA({a}); - const float alpha = adata[0]; - float tmp, xx; + const float alpha = ((float*)PyArray_DATA({a}))[0]; int Ai = PyArray_STRIDES({A})[0]/sizeof(float); int Aj = PyArray_STRIDES({A})[1]/sizeof(float); int Zi = PyArray_STRIDES({Z})[0]/sizeof(float); int Zj = PyArray_STRIDES({Z})[1]/sizeof(float); int xi = PyArray_STRIDES({x})[0]/sizeof(float); int yj = PyArray_STRIDES({y})[0]/sizeof(float); - for (int i = 0; i < dims[0]; ++i) - {{ - xx = alpha * xdata[xi * i]; - for (int j = 0; j < dims[1]; ++j) - {{ - tmp = zdata[Ai*i+Aj*j]; - tmp += xx * ydata[yj * j]; - zoutdata[Zi*i+Zj*j] = tmp; - }} - }} + pytensor_sger_manual_copy(dims[0], dims[1], + zdata, Ai, Aj, zoutdata, Zi, Zj, + xdata, xi, ydata, yj, alpha); }} else if (PyArray_DESCR({Z})->type_num == NPY_DOUBLE) {{ - double * zoutdata = (double*) PyArray_DATA({Z}); const double * zdata = (double*)PyArray_DATA({A}); + double * zoutdata = (double*) PyArray_DATA({Z}); const double * xdata = (double*)PyArray_DATA({x}); const double * ydata = (double*)PyArray_DATA({y}); - const double * adata = (double*)PyArray_DATA({a}); - const double alpha = adata[0]; - double tmp, xx; - + const double alpha = ((double*)PyArray_DATA({a}))[0]; int Ai = PyArray_STRIDES({A})[0]/sizeof(double); int Aj = PyArray_STRIDES({A})[1]/sizeof(double); int Zi = PyArray_STRIDES({Z})[0]/sizeof(double); int Zj = PyArray_STRIDES({Z})[1]/sizeof(double); int xi = PyArray_STRIDES({x})[0]/sizeof(double); int yj = PyArray_STRIDES({y})[0]/sizeof(double); - for (int i = 0; i < dims[0]; ++i) - {{ - xx = alpha * xdata[xi * i]; - for (int j = 0; j < dims[1]; ++j) - {{ - tmp = zdata[Ai*i+Aj*j]; - tmp += xx * ydata[yj * j]; - zoutdata[Zi*i+Zj*j] = tmp; - }} - }} + pytensor_dger_manual_copy(dims[0], dims[1], + zdata, Ai, Aj, zoutdata, Zi, Zj, + xdata, xi, ydata, yj, alpha); }} else {{ @@ -186,50 +172,33 @@ def ger_c_code(A, a, x, y, Z, fail, params): npy_intp dims[2]; dims[0] = PyArray_DIMS({A})[0]; dims[1] = PyArray_DIMS({A})[1]; - if ((dims[0] * dims[1]) < 100000) + if ((dims[0] * dims[1]) < PYTENSOR_GER_BLAS_THRESHOLD) {{ if (PyArray_DESCR({Z})->type_num == NPY_FLOAT) {{ float * zoutdata = (float*)PyArray_DATA({Z}); const float * xdata = (float*)PyArray_DATA({x}); const float * ydata = (float*)PyArray_DATA({y}); - const float * adata = (float*)PyArray_DATA({a}); - const float alpha = adata[0]; - float tmp, axi; + const float alpha = ((float*)PyArray_DATA({a}))[0]; int Zi = PyArray_STRIDES({Z})[0]/sizeof(float); int Zj = PyArray_STRIDES({Z})[1]/sizeof(float); int xi = PyArray_STRIDES({x})[0]/sizeof(float); int yj = PyArray_STRIDES({y})[0]/sizeof(float); - for (int i = 0; i < dims[0]; ++i) - {{ - axi = alpha * xdata[xi * i]; - for (int j = 0; j < dims[1]; ++j) - {{ - zoutdata[Zi*i+Zj*j] += axi * ydata[yj * j]; - }} - }} + pytensor_sger_manual_inplace(dims[0], dims[1], + zoutdata, Zi, Zj, xdata, xi, ydata, yj, alpha); }} else if (PyArray_DESCR({Z})->type_num == NPY_DOUBLE) {{ double * zoutdata = (double*) PyArray_DATA({Z}); const double * xdata = (double*)PyArray_DATA({x}); const double * ydata = (double*)PyArray_DATA({y}); - const double * adata = (double*)PyArray_DATA({a}); - const double alpha = adata[0]; - double tmp, axi; - + const double alpha = ((double*)PyArray_DATA({a}))[0]; int Zi = PyArray_STRIDES({Z})[0]/sizeof(double); int Zj = PyArray_STRIDES({Z})[1]/sizeof(double); int xi = PyArray_STRIDES({x})[0]/sizeof(double); int yj = PyArray_STRIDES({y})[0]/sizeof(double); - for (int i = 0; i < dims[0]; ++i) - {{ - axi = alpha * xdata[xi * i]; - for (int j = 0; j < dims[1]; ++j) - {{ - zoutdata[Zi*i+Zj*j] += axi * ydata[yj * j]; - }} - }} + pytensor_dger_manual_inplace(dims[0], dims[1], + zoutdata, Zi, Zj, xdata, xi, ydata, yj, alpha); }} }} else @@ -239,81 +208,39 @@ def ger_c_code(A, a, x, y, Z, fail, params): int Sx = PyArray_STRIDES({x})[0] / elemsize; int Sy = PyArray_STRIDES({y})[0] / elemsize; - /* create appropriate strides for Z, if it is a row or column matrix. - * In that case, the value of the stride does not really matter, but - * some versions of BLAS insist that: - * - they are not smaller than the number of elements in the array, - * - they are not 0. - */ - int Sz0 = (Nz0 > 1) ? (PyArray_STRIDES({Z})[0] / elemsize) : (Nz1 + 1); - int Sz1 = (Nz1 > 1) ? (PyArray_STRIDES({Z})[1] / elemsize) : (Nz0 + 1); - dtype_{x}* x_data = (dtype_{x}*) PyArray_DATA({x}); dtype_{y}* y_data = (dtype_{y}*) PyArray_DATA({y}); - // gemv expects pointers to the beginning of memory arrays, - // but numpy provides provides a pointer to the first element, + // ger expects pointers to the beginning of memory arrays, + // but numpy provides a pointer to the first element, // so when the stride is negative, we need to get the last one. if (Sx < 0) x_data += (Nz0 - 1) * Sx; if (Sy < 0) y_data += (Nz1 - 1) * Sy; - if (PyArray_STRIDES({Z})[0] == elemsize) + if (PyArray_DESCR({Z})->type_num == NPY_FLOAT) {{ - if (PyArray_DESCR({Z})->type_num == NPY_FLOAT) - {{ - float alpha = ((dtype_{a}*)PyArray_DATA({a}))[0]; - sger_(&Nz0, &Nz1, &alpha, - (float*)x_data, &Sx, - (float*)y_data, &Sy, - (float*)(PyArray_DATA({Z})), &Sz1); - }} - else if (PyArray_DESCR({Z})->type_num == NPY_DOUBLE) - {{ - double alpha = ((dtype_{a}*)PyArray_DATA({a}))[0]; - dger_(&Nz0, &Nz1, &alpha, - (double*)x_data, &Sx, - (double*)y_data, &Sy, - (double*)(PyArray_DATA({Z})), &Sz1); - - - }} - else {{ - PyErr_SetString(PyExc_NotImplementedError, - "not float nor double"); + float alpha = ((dtype_{a}*)PyArray_DATA({a}))[0]; + if (pytensor_sger_dispatch(Nz0, Nz1, + PyArray_STRIDES({Z})[0], PyArray_STRIDES({Z})[1], elemsize, + (float*)PyArray_DATA({Z}), (float*)x_data, (float*)y_data, + alpha, Sx, Sy) != 0) {{ {fail} }} }} - else if (PyArray_STRIDES({Z})[1] == elemsize) + else if (PyArray_DESCR({Z})->type_num == NPY_DOUBLE) {{ - if (PyArray_DESCR({Z})->type_num == NPY_FLOAT) - {{ - float alpha = ((dtype_{a}*)(PyArray_DATA({a})))[0]; - sger_(&Nz1, &Nz0, &alpha, - (float*)y_data, &Sy, - (float*)x_data, &Sx, - (float*)(PyArray_DATA({Z})), &Sz0); - }} - else if (PyArray_DESCR({Z})->type_num == NPY_DOUBLE) - {{ - double alpha = ((dtype_{a}*)PyArray_DATA({a}))[0]; - dger_(&Nz1, &Nz0, &alpha, - (double*)y_data, &Sy, - (double*)x_data, &Sx, - (double*)(PyArray_DATA({Z})), &Sz0); - }} - else - {{ - PyErr_SetString(PyExc_NotImplementedError, - "not float nor double"); + double alpha = ((dtype_{a}*)PyArray_DATA({a}))[0]; + if (pytensor_dger_dispatch(Nz0, Nz1, + PyArray_STRIDES({Z})[0], PyArray_STRIDES({Z})[1], elemsize, + (double*)PyArray_DATA({Z}), (double*)x_data, (double*)y_data, + alpha, Sx, Sy) != 0) {{ {fail} }} }} else {{ - PyErr_SetString(PyExc_AssertionError, - "A is a double-strided matrix, and should have been copied " - "into a memory-contiguous one."); + PyErr_SetString(PyExc_NotImplementedError, "not float nor double"); {fail} }} }} @@ -334,7 +261,7 @@ def c_code(self, node, name, inp, out, sub): return code def c_code_cache_version(self): - return (11, blas_header_version()) + return (12, blas_header_version()) cger_inplace = CGer(True) diff --git a/pytensor/tensor/blas/c_code/ger_helper.h b/pytensor/tensor/blas/c_code/ger_helper.h new file mode 100644 index 0000000000..a838cd9346 --- /dev/null +++ b/pytensor/tensor/blas/c_code/ger_helper.h @@ -0,0 +1,187 @@ +/* + * GER helper functions for PyTensor. + * + * This file contains GER (rank-1 update) dispatch logic extracted from + * Python code generation templates. + * + * GER computes: A <- A + alpha * x * y^T + * where A is a matrix and x, y are vectors. + */ + +#ifndef PYTENSOR_GER_HELPER_H +#define PYTENSOR_GER_HELPER_H + +#include +#include + +/* Include BLAS declarations */ +#include "fortran_blas.h" + +/* + * Check if a matrix needs to be copied for GER. + * + * GER requires the matrix to have at least one unit stride. + * Returns 1 if copy needed, 0 if matrix can be used directly. + */ +static inline int pytensor_ger_needs_copy(npy_intp stride0, npy_intp stride1, int elemsize) { + return (stride0 < 0) || (stride1 < 0) || + ((stride0 != elemsize) && (stride1 != elemsize)); +} + +/* + * Manual float GER with copy: Z = A + alpha * x * y^T + * + * Used when A needs to be copied (non-contiguous or non-destructive). + * Reads from A (zdata), writes to Z (zoutdata). + */ +static inline void pytensor_sger_manual_copy( + int dims0, int dims1, + const float *zdata, int Ai, int Aj, + float *zoutdata, int Zi, int Zj, + const float *xdata, int xi, + const float *ydata, int yj, + float alpha +) { + for (int i = 0; i < dims0; ++i) { + float xx = alpha * xdata[xi * i]; + for (int j = 0; j < dims1; ++j) { + float tmp = zdata[Ai*i + Aj*j]; + tmp += xx * ydata[yj * j]; + zoutdata[Zi*i + Zj*j] = tmp; + } + } +} + +/* + * Manual double GER with copy: Z = A + alpha * x * y^T + */ +static inline void pytensor_dger_manual_copy( + int dims0, int dims1, + const double *zdata, int Ai, int Aj, + double *zoutdata, int Zi, int Zj, + const double *xdata, int xi, + const double *ydata, int yj, + double alpha +) { + for (int i = 0; i < dims0; ++i) { + double xx = alpha * xdata[xi * i]; + for (int j = 0; j < dims1; ++j) { + double tmp = zdata[Ai*i + Aj*j]; + tmp += xx * ydata[yj * j]; + zoutdata[Zi*i + Zj*j] = tmp; + } + } +} + +/* + * Manual float GER inplace: Z += alpha * x * y^T + * + * Used for small matrices where calling BLAS has overhead. + */ +static inline void pytensor_sger_manual_inplace( + int dims0, int dims1, + float *zoutdata, int Zi, int Zj, + const float *xdata, int xi, + const float *ydata, int yj, + float alpha +) { + for (int i = 0; i < dims0; ++i) { + float axi = alpha * xdata[xi * i]; + for (int j = 0; j < dims1; ++j) { + zoutdata[Zi*i + Zj*j] += axi * ydata[yj * j]; + } + } +} + +/* + * Manual double GER inplace: Z += alpha * x * y^T + */ +static inline void pytensor_dger_manual_inplace( + int dims0, int dims1, + double *zoutdata, int Zi, int Zj, + const double *xdata, int xi, + const double *ydata, int yj, + double alpha +) { + for (int i = 0; i < dims0; ++i) { + double axi = alpha * xdata[xi * i]; + for (int j = 0; j < dims1; ++j) { + zoutdata[Zi*i + Zj*j] += axi * ydata[yj * j]; + } + } +} + +/* + * Call sger_ for float rank-1 update. + * + * Handles both C-contiguous and F-contiguous layouts. + * For F-contiguous (stride[0] == elemsize): sger_(Nz0, Nz1, ..., x, y, ...) + * For C-contiguous (stride[1] == elemsize): sger_(Nz1, Nz0, ..., y, x, ...) + * + * Returns 0 on success, -1 on error. + */ +static inline int pytensor_sger_dispatch( + int Nz0, int Nz1, + npy_intp stride0, npy_intp stride1, + int elemsize, + float *z_data, float *x_data, float *y_data, + float alpha, int Sx, int Sy +) { + /* Compute BLAS-compatible strides */ + int Sz0 = (Nz0 > 1) ? (stride0 / elemsize) : (Nz1 + 1); + int Sz1 = (Nz1 > 1) ? (stride1 / elemsize) : (Nz0 + 1); + + if (stride0 == elemsize) { + /* F-contiguous */ + sger_(&Nz0, &Nz1, &alpha, x_data, &Sx, y_data, &Sy, z_data, &Sz1); + } else if (stride1 == elemsize) { + /* C-contiguous: swap dimensions and vectors */ + sger_(&Nz1, &Nz0, &alpha, y_data, &Sy, x_data, &Sx, z_data, &Sz0); + } else { + PyErr_SetString(PyExc_AssertionError, + "A is a double-strided matrix, and should have been copied " + "into a memory-contiguous one."); + return -1; + } + return 0; +} + +/* + * Call dger_ for double rank-1 update. + * + * Returns 0 on success, -1 on error. + */ +static inline int pytensor_dger_dispatch( + int Nz0, int Nz1, + npy_intp stride0, npy_intp stride1, + int elemsize, + double *z_data, double *x_data, double *y_data, + double alpha, int Sx, int Sy +) { + /* Compute BLAS-compatible strides */ + int Sz0 = (Nz0 > 1) ? (stride0 / elemsize) : (Nz1 + 1); + int Sz1 = (Nz1 > 1) ? (stride1 / elemsize) : (Nz0 + 1); + + if (stride0 == elemsize) { + /* F-contiguous */ + dger_(&Nz0, &Nz1, &alpha, x_data, &Sx, y_data, &Sy, z_data, &Sz1); + } else if (stride1 == elemsize) { + /* C-contiguous: swap dimensions and vectors */ + dger_(&Nz1, &Nz0, &alpha, y_data, &Sy, x_data, &Sx, z_data, &Sz0); + } else { + PyErr_SetString(PyExc_AssertionError, + "A is a double-strided matrix, and should have been copied " + "into a memory-contiguous one."); + return -1; + } + return 0; +} + +/* + * Threshold for using manual loop vs BLAS GER. + * For small matrices, the overhead of calling BLAS is not worth it. + */ +#define PYTENSOR_GER_BLAS_THRESHOLD 100000 + +#endif /* PYTENSOR_GER_HELPER_H */ + From 2d5fcf279094194439a251f6b01f2278d73d29a0 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 12 Apr 2026 20:29:56 -0500 Subject: [PATCH 11/13] Extract GER codegen into a C function, codegen becomes a function call --- pytensor/tensor/blas/blas_c.py | 222 ++----------------------- pytensor/tensor/blas/c_code/ger_op.h | 238 +++++++++++++++++++++++++++ 2 files changed, 256 insertions(+), 204 deletions(-) create mode 100644 pytensor/tensor/blas/c_code/ger_op.h diff --git a/pytensor/tensor/blas/blas_c.py b/pytensor/tensor/blas/blas_c.py index 0fdbeff4a0..84cd8b25ea 100644 --- a/pytensor/tensor/blas/blas_c.py +++ b/pytensor/tensor/blas/blas_c.py @@ -23,6 +23,11 @@ def _read_ger_helper_h(): return _read_c_code_file("ger_helper.h") +def _read_ger_op_h(): + """Read the complete GER operation header file.""" + return _read_c_code_file("ger_op.h") + + class BaseBLAS(COp): def c_libraries(self, **kwargs): return ldflags() @@ -39,7 +44,7 @@ def c_header_dirs(self, **kwargs): return [c_code_dir, *ldflags(libs=False, include_dir=True)] def c_support_code(self, **kwargs): - return blas_header_text() + _read_gemv_helper_h() + _read_ger_helper_h() + return blas_header_text() + _read_gemv_helper_h() + _read_ger_op_h() # ##### ####### ####### @@ -47,209 +52,13 @@ def c_support_code(self, **kwargs): # ##### ####### ####### -def ger_c_code(A, a, x, y, Z, fail, params): - return f""" - - int elemsize ; - - if (PyArray_NDIM({A}) != 2) - {{PyErr_SetString(PyExc_NotImplementedError, "rank(A) != 2"); {fail};}} - if (PyArray_NDIM({x}) != 1) - {{PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 1"); {fail};}} - if (PyArray_NDIM({y}) != 1) - {{PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 1"); {fail};}} - if (PyArray_NDIM({a}) != 0) - {{PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 0"); {fail};}} - - if (PyArray_DESCR({A})->type_num != PyArray_DESCR({x})->type_num) - {{ PyErr_SetString(PyExc_TypeError, "A vs. x"); {fail}; }} - if (PyArray_DESCR({A})->type_num != PyArray_DESCR({y})->type_num) - {{ PyErr_SetString(PyExc_TypeError, "A vs. y"); {fail}; }} - - if (PyArray_DIMS({A})[0] != PyArray_DIMS({x})[0]) - {{ - PyErr_SetString(PyExc_ValueError, - "Shape mismatch: A.shape[0] != x.shape[0]"); - {fail}; - }} - if (PyArray_DIMS({A})[1] != PyArray_DIMS({y})[0]) - {{ - PyErr_SetString(PyExc_ValueError, - "Shape mismatch: A.shape[1] != y.shape[0]"); - {fail}; - }} - - if (PyArray_DESCR({A})->type_num == NPY_DOUBLE) {{ elemsize = 8; }} - else if (PyArray_DESCR({A})->type_num == NPY_FLOAT) {{ elemsize = 4;}} - else - {{ - PyErr_SetString(PyExc_NotImplementedError, "complex CGer"); - {fail}; - }} - - // copy A if !self.destructive or A is fully strided - if (!{params}->destructive - || (PyArray_STRIDES({A})[0] < 0) - || (PyArray_STRIDES({A})[1] < 0) - || ((PyArray_STRIDES({A})[0] != elemsize) - && (PyArray_STRIDES({A})[1] != elemsize))) - {{ - npy_intp dims[2]; - dims[0] = PyArray_DIMS({A})[0]; - dims[1] = PyArray_DIMS({A})[1]; - - if ((NULL == {Z}) - || (PyArray_DIMS({Z})[0] != PyArray_DIMS({A})[0]) - || (PyArray_DIMS({Z})[1] != PyArray_DIMS({A})[1]) - || (PyArray_STRIDES({Z})[0] < 0) - || (PyArray_STRIDES({Z})[1] < 0) - || ((PyArray_STRIDES({Z})[0] != elemsize) - && (PyArray_STRIDES({Z})[1] != elemsize))) - {{ - Py_XDECREF({Z}); - {Z} = (PyArrayObject*) PyArray_SimpleNew(2, dims, - PyArray_TYPE({A})); - if(!{Z}) {{ - PyErr_SetString(PyExc_MemoryError, - "failed to alloc ger output"); - {fail} - }} - }} - if ({Z} == {A}) - {{ - PyErr_SetString(PyExc_AssertionError, "{Z} != {A}"); - {fail} - }} - if (PyArray_DESCR({Z})->type_num == NPY_FLOAT) - {{ - const float * zdata = (float*)PyArray_DATA({A}); - float * zoutdata = (float*)PyArray_DATA({Z}); - const float * xdata = (float*)PyArray_DATA({x}); - const float * ydata = (float*)PyArray_DATA({y}); - const float alpha = ((float*)PyArray_DATA({a}))[0]; - int Ai = PyArray_STRIDES({A})[0]/sizeof(float); - int Aj = PyArray_STRIDES({A})[1]/sizeof(float); - int Zi = PyArray_STRIDES({Z})[0]/sizeof(float); - int Zj = PyArray_STRIDES({Z})[1]/sizeof(float); - int xi = PyArray_STRIDES({x})[0]/sizeof(float); - int yj = PyArray_STRIDES({y})[0]/sizeof(float); - pytensor_sger_manual_copy(dims[0], dims[1], - zdata, Ai, Aj, zoutdata, Zi, Zj, - xdata, xi, ydata, yj, alpha); - }} - else if (PyArray_DESCR({Z})->type_num == NPY_DOUBLE) - {{ - const double * zdata = (double*)PyArray_DATA({A}); - double * zoutdata = (double*) PyArray_DATA({Z}); - const double * xdata = (double*)PyArray_DATA({x}); - const double * ydata = (double*)PyArray_DATA({y}); - const double alpha = ((double*)PyArray_DATA({a}))[0]; - int Ai = PyArray_STRIDES({A})[0]/sizeof(double); - int Aj = PyArray_STRIDES({A})[1]/sizeof(double); - int Zi = PyArray_STRIDES({Z})[0]/sizeof(double); - int Zj = PyArray_STRIDES({Z})[1]/sizeof(double); - int xi = PyArray_STRIDES({x})[0]/sizeof(double); - int yj = PyArray_STRIDES({y})[0]/sizeof(double); - pytensor_dger_manual_copy(dims[0], dims[1], - zdata, Ai, Aj, zoutdata, Zi, Zj, - xdata, xi, ydata, yj, alpha); - }} - else - {{ - PyErr_SetString(PyExc_AssertionError, - "neither float nor double dtype"); - {fail} - }} - }} - else - {{ - if ({Z} != {A}) - {{ - if ({Z}) {{ Py_DECREF({Z}); }} - {Z} = {A}; - Py_INCREF({Z}); - }} - npy_intp dims[2]; - dims[0] = PyArray_DIMS({A})[0]; - dims[1] = PyArray_DIMS({A})[1]; - if ((dims[0] * dims[1]) < PYTENSOR_GER_BLAS_THRESHOLD) - {{ - if (PyArray_DESCR({Z})->type_num == NPY_FLOAT) - {{ - float * zoutdata = (float*)PyArray_DATA({Z}); - const float * xdata = (float*)PyArray_DATA({x}); - const float * ydata = (float*)PyArray_DATA({y}); - const float alpha = ((float*)PyArray_DATA({a}))[0]; - int Zi = PyArray_STRIDES({Z})[0]/sizeof(float); - int Zj = PyArray_STRIDES({Z})[1]/sizeof(float); - int xi = PyArray_STRIDES({x})[0]/sizeof(float); - int yj = PyArray_STRIDES({y})[0]/sizeof(float); - pytensor_sger_manual_inplace(dims[0], dims[1], - zoutdata, Zi, Zj, xdata, xi, ydata, yj, alpha); - }} - else if (PyArray_DESCR({Z})->type_num == NPY_DOUBLE) - {{ - double * zoutdata = (double*) PyArray_DATA({Z}); - const double * xdata = (double*)PyArray_DATA({x}); - const double * ydata = (double*)PyArray_DATA({y}); - const double alpha = ((double*)PyArray_DATA({a}))[0]; - int Zi = PyArray_STRIDES({Z})[0]/sizeof(double); - int Zj = PyArray_STRIDES({Z})[1]/sizeof(double); - int xi = PyArray_STRIDES({x})[0]/sizeof(double); - int yj = PyArray_STRIDES({y})[0]/sizeof(double); - pytensor_dger_manual_inplace(dims[0], dims[1], - zoutdata, Zi, Zj, xdata, xi, ydata, yj, alpha); - }} - }} - else - {{ - int Nz0 = PyArray_DIMS({Z})[0]; - int Nz1 = PyArray_DIMS({Z})[1]; - int Sx = PyArray_STRIDES({x})[0] / elemsize; - int Sy = PyArray_STRIDES({y})[0] / elemsize; - - dtype_{x}* x_data = (dtype_{x}*) PyArray_DATA({x}); - dtype_{y}* y_data = (dtype_{y}*) PyArray_DATA({y}); - // ger expects pointers to the beginning of memory arrays, - // but numpy provides a pointer to the first element, - // so when the stride is negative, we need to get the last one. - if (Sx < 0) - x_data += (Nz0 - 1) * Sx; - if (Sy < 0) - y_data += (Nz1 - 1) * Sy; - - if (PyArray_DESCR({Z})->type_num == NPY_FLOAT) - {{ - float alpha = ((dtype_{a}*)PyArray_DATA({a}))[0]; - if (pytensor_sger_dispatch(Nz0, Nz1, - PyArray_STRIDES({Z})[0], PyArray_STRIDES({Z})[1], elemsize, - (float*)PyArray_DATA({Z}), (float*)x_data, (float*)y_data, - alpha, Sx, Sy) != 0) {{ - {fail} - }} - }} - else if (PyArray_DESCR({Z})->type_num == NPY_DOUBLE) - {{ - double alpha = ((dtype_{a}*)PyArray_DATA({a}))[0]; - if (pytensor_dger_dispatch(Nz0, Nz1, - PyArray_STRIDES({Z})[0], PyArray_STRIDES({Z})[1], elemsize, - (double*)PyArray_DATA({Z}), (double*)x_data, (double*)y_data, - alpha, Sx, Sy) != 0) {{ - {fail} - }} - }} - else - {{ - PyErr_SetString(PyExc_NotImplementedError, "not float nor double"); - {fail} - }} - }} - }} +class CGer(BaseBLAS, Ger): + """C implementation of GER (rank-1 update): Z = A + alpha * outer(x, y). + This uses the pytensor_ger() function from ger_op.h which handles + all validation, allocation, and computation in C. """ - -class CGer(BaseBLAS, Ger): params_type = ParamsType( destructive=bool_t, ) @@ -257,11 +66,16 @@ class CGer(BaseBLAS, Ger): def c_code(self, node, name, inp, out, sub): A, a, x, y = inp (Z,) = out - code = ger_c_code(A, a, x, y, Z, fail=sub["fail"], params=sub["params"]) - return code + fail = sub["fail"] + params = sub["params"] + return f""" + if (pytensor_ger({A}, {a}, {x}, {y}, &{Z}, {params}->destructive) != 0) {{ + {fail} + }} + """ def c_code_cache_version(self): - return (12, blas_header_version()) + return (13, blas_header_version()) cger_inplace = CGer(True) diff --git a/pytensor/tensor/blas/c_code/ger_op.h b/pytensor/tensor/blas/c_code/ger_op.h new file mode 100644 index 0000000000..83891aeda5 --- /dev/null +++ b/pytensor/tensor/blas/c_code/ger_op.h @@ -0,0 +1,238 @@ +/* + * Complete GER operation for PyTensor. + * + * This file contains a top-level GER function that handles: + * - Input validation (rank, dtype, shape checks) + * - Output allocation (with proper reference counting) + * - Computation dispatch (manual loops or BLAS) + * + * This enables minimal Python codegen - just a single function call. + */ + +#ifndef PYTENSOR_GER_OP_H +#define PYTENSOR_GER_OP_H + +#include +#include + +/* Include the helper functions */ +#include "ger_helper.h" + +/* + * Perform complete GER operation: Z = A + alpha * outer(x, y) + * + * Parameters: + * A - Input matrix (2D) + * a - Scalar alpha (0D) + * x - Input vector (1D), length must match A.shape[0] + * y - Input vector (1D), length must match A.shape[1] + * Z_ptr - Pointer to output array pointer (will be set/updated) + * destructive - If true and A is contiguous, operate in-place on A + * + * Returns: + * 0 on success, -1 on error (Python exception set) + * + * Note: Caller is responsible for reference counting. If *Z_ptr is changed, + * the old value is DECREF'd and the new value is returned with a new reference. + */ +static int pytensor_ger( + PyArrayObject *A, + PyArrayObject *a, + PyArrayObject *x, + PyArrayObject *y, + PyArrayObject **Z_ptr, + int destructive +) { + int elemsize; + npy_intp dims[2]; + PyArrayObject *Z = *Z_ptr; + + /* Validate ranks */ + if (PyArray_NDIM(A) != 2) { + PyErr_SetString(PyExc_NotImplementedError, "rank(A) != 2"); + return -1; + } + if (PyArray_NDIM(x) != 1) { + PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 1"); + return -1; + } + if (PyArray_NDIM(y) != 1) { + PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 1"); + return -1; + } + if (PyArray_NDIM(a) != 0) { + PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 0"); + return -1; + } + + /* Validate dtypes match */ + if (PyArray_DESCR(A)->type_num != PyArray_DESCR(x)->type_num) { + PyErr_SetString(PyExc_TypeError, "A vs. x dtype mismatch"); + return -1; + } + if (PyArray_DESCR(A)->type_num != PyArray_DESCR(y)->type_num) { + PyErr_SetString(PyExc_TypeError, "A vs. y dtype mismatch"); + return -1; + } + + /* Validate shapes */ + if (PyArray_DIMS(A)[0] != PyArray_DIMS(x)[0]) { + PyErr_SetString(PyExc_ValueError, "Shape mismatch: A.shape[0] != x.shape[0]"); + return -1; + } + if (PyArray_DIMS(A)[1] != PyArray_DIMS(y)[0]) { + PyErr_SetString(PyExc_ValueError, "Shape mismatch: A.shape[1] != y.shape[0]"); + return -1; + } + + /* Determine element size */ + if (PyArray_DESCR(A)->type_num == NPY_DOUBLE) { + elemsize = 8; + } else if (PyArray_DESCR(A)->type_num == NPY_FLOAT) { + elemsize = 4; + } else { + PyErr_SetString(PyExc_NotImplementedError, "complex CGer not implemented"); + return -1; + } + + dims[0] = PyArray_DIMS(A)[0]; + dims[1] = PyArray_DIMS(A)[1]; + + /* Decide: copy A or operate in-place */ + if (!destructive || pytensor_ger_needs_copy( + PyArray_STRIDES(A)[0], PyArray_STRIDES(A)[1], elemsize)) { + /* + * Need to copy: either non-destructive mode or A has bad strides. + * Allocate Z if needed, then copy A into Z and compute. + */ + int need_alloc = (Z == NULL) + || (PyArray_DIMS(Z)[0] != dims[0]) + || (PyArray_DIMS(Z)[1] != dims[1]) + || pytensor_ger_needs_copy( + PyArray_STRIDES(Z)[0], PyArray_STRIDES(Z)[1], elemsize); + + if (need_alloc) { + Py_XDECREF(Z); + Z = (PyArrayObject *)PyArray_SimpleNew(2, dims, PyArray_TYPE(A)); + if (!Z) { + PyErr_SetString(PyExc_MemoryError, "failed to alloc ger output"); + return -1; + } + *Z_ptr = Z; + } + + if (Z == A) { + PyErr_SetString(PyExc_AssertionError, "Z should not be A in copy path"); + return -1; + } + + /* Copy A to Z and add outer product */ + if (PyArray_DESCR(Z)->type_num == NPY_FLOAT) { + const float *zdata = (float *)PyArray_DATA(A); + float *zoutdata = (float *)PyArray_DATA(Z); + const float *xdata = (float *)PyArray_DATA(x); + const float *ydata = (float *)PyArray_DATA(y); + float alpha = ((float *)PyArray_DATA(a))[0]; + int Ai = PyArray_STRIDES(A)[0] / sizeof(float); + int Aj = PyArray_STRIDES(A)[1] / sizeof(float); + int Zi = PyArray_STRIDES(Z)[0] / sizeof(float); + int Zj = PyArray_STRIDES(Z)[1] / sizeof(float); + int xi = PyArray_STRIDES(x)[0] / sizeof(float); + int yj = PyArray_STRIDES(y)[0] / sizeof(float); + pytensor_sger_manual_copy(dims[0], dims[1], + zdata, Ai, Aj, zoutdata, Zi, Zj, + xdata, xi, ydata, yj, alpha); + } else { + const double *zdata = (double *)PyArray_DATA(A); + double *zoutdata = (double *)PyArray_DATA(Z); + const double *xdata = (double *)PyArray_DATA(x); + const double *ydata = (double *)PyArray_DATA(y); + double alpha = ((double *)PyArray_DATA(a))[0]; + int Ai = PyArray_STRIDES(A)[0] / sizeof(double); + int Aj = PyArray_STRIDES(A)[1] / sizeof(double); + int Zi = PyArray_STRIDES(Z)[0] / sizeof(double); + int Zj = PyArray_STRIDES(Z)[1] / sizeof(double); + int xi = PyArray_STRIDES(x)[0] / sizeof(double); + int yj = PyArray_STRIDES(y)[0] / sizeof(double); + pytensor_dger_manual_copy(dims[0], dims[1], + zdata, Ai, Aj, zoutdata, Zi, Zj, + xdata, xi, ydata, yj, alpha); + } + } else { + /* + * Destructive mode with good strides: operate in-place on A. + */ + if (Z != A) { + Py_XDECREF(Z); + Z = A; + Py_INCREF(Z); + *Z_ptr = Z; + } + + if ((dims[0] * dims[1]) < PYTENSOR_GER_BLAS_THRESHOLD) { + /* Small matrix: use manual loop */ + if (PyArray_DESCR(Z)->type_num == NPY_FLOAT) { + float *zoutdata = (float *)PyArray_DATA(Z); + const float *xdata = (float *)PyArray_DATA(x); + const float *ydata = (float *)PyArray_DATA(y); + float alpha = ((float *)PyArray_DATA(a))[0]; + int Zi = PyArray_STRIDES(Z)[0] / sizeof(float); + int Zj = PyArray_STRIDES(Z)[1] / sizeof(float); + int xi = PyArray_STRIDES(x)[0] / sizeof(float); + int yj = PyArray_STRIDES(y)[0] / sizeof(float); + pytensor_sger_manual_inplace(dims[0], dims[1], + zoutdata, Zi, Zj, xdata, xi, ydata, yj, alpha); + } else { + double *zoutdata = (double *)PyArray_DATA(Z); + const double *xdata = (double *)PyArray_DATA(x); + const double *ydata = (double *)PyArray_DATA(y); + double alpha = ((double *)PyArray_DATA(a))[0]; + int Zi = PyArray_STRIDES(Z)[0] / sizeof(double); + int Zj = PyArray_STRIDES(Z)[1] / sizeof(double); + int xi = PyArray_STRIDES(x)[0] / sizeof(double); + int yj = PyArray_STRIDES(y)[0] / sizeof(double); + pytensor_dger_manual_inplace(dims[0], dims[1], + zoutdata, Zi, Zj, xdata, xi, ydata, yj, alpha); + } + } else { + /* Large matrix: use BLAS */ + int Nz0 = dims[0]; + int Nz1 = dims[1]; + int Sx = PyArray_STRIDES(x)[0] / elemsize; + int Sy = PyArray_STRIDES(y)[0] / elemsize; + + /* Handle negative strides */ + void *x_data = PyArray_DATA(x); + void *y_data = PyArray_DATA(y); + if (Sx < 0) { + x_data = (char *)x_data + (Nz0 - 1) * Sx * elemsize; + } + if (Sy < 0) { + y_data = (char *)y_data + (Nz1 - 1) * Sy * elemsize; + } + + if (PyArray_DESCR(Z)->type_num == NPY_FLOAT) { + float alpha = ((float *)PyArray_DATA(a))[0]; + if (pytensor_sger_dispatch(Nz0, Nz1, + PyArray_STRIDES(Z)[0], PyArray_STRIDES(Z)[1], elemsize, + (float *)PyArray_DATA(Z), (float *)x_data, (float *)y_data, + alpha, Sx, Sy) != 0) { + return -1; + } + } else { + double alpha = ((double *)PyArray_DATA(a))[0]; + if (pytensor_dger_dispatch(Nz0, Nz1, + PyArray_STRIDES(Z)[0], PyArray_STRIDES(Z)[1], elemsize, + (double *)PyArray_DATA(Z), (double *)x_data, (double *)y_data, + alpha, Sx, Sy) != 0) { + return -1; + } + } + } + } + + return 0; +} + +#endif /* PYTENSOR_GER_OP_H */ + From 26909a1c8c5d61cf9490cb12a8a8ba54bbdf67ba Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 19 Apr 2026 00:55:52 -0500 Subject: [PATCH 12/13] Don't cache blas_header_text --- pytensor/tensor/blas/blas_headers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytensor/tensor/blas/blas_headers.py b/pytensor/tensor/blas/blas_headers.py index 050b9d9848..59c861e9aa 100644 --- a/pytensor/tensor/blas/blas_headers.py +++ b/pytensor/tensor/blas/blas_headers.py @@ -130,7 +130,6 @@ def _read_c_code_file(filename: str) -> str: raise OSError(msg) from err -@functools.cache def blas_header_text(): """C header for the fortran blas interface. From 8b69f18eef562b5fb8e2bcfc4bc89264d078292a Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sun, 19 Apr 2026 01:13:54 -0500 Subject: [PATCH 13/13] mypy problems? IGNORE IGNORE IGNORE --- pytensor/tensor/blas/blas_c.py | 2 +- pytensor/tensor/blas/blas_headers.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytensor/tensor/blas/blas_c.py b/pytensor/tensor/blas/blas_c.py index 84cd8b25ea..41ab384441 100644 --- a/pytensor/tensor/blas/blas_c.py +++ b/pytensor/tensor/blas/blas_c.py @@ -393,4 +393,4 @@ def must_initialize_y_gemv(): return must_initialize_y_gemv._force_init_beta -must_initialize_y_gemv._force_init_beta = None +must_initialize_y_gemv._force_init_beta = None # type: ignore[attr-defined] diff --git a/pytensor/tensor/blas/blas_headers.py b/pytensor/tensor/blas/blas_headers.py index 59c861e9aa..55f1adcb41 100644 --- a/pytensor/tensor/blas/blas_headers.py +++ b/pytensor/tensor/blas/blas_headers.py @@ -114,9 +114,9 @@ def detect_macos_sdot_bug(): return detect_macos_sdot_bug.present -detect_macos_sdot_bug.tested = False -detect_macos_sdot_bug.present = False -detect_macos_sdot_bug.fix_works = False +detect_macos_sdot_bug.tested = False # type: ignore[attr-defined] +detect_macos_sdot_bug.present = False # type: ignore[attr-defined] +detect_macos_sdot_bug.fix_works = False # type: ignore[attr-defined] @functools.cache