diff --git a/CHANGELOG.md b/CHANGELOG.md index f55db47..7f7e10d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,11 +4,21 @@ ### Added - **New `hypyp.sync` module**: Modular architecture for connectivity metrics - - Extracted 9 connectivity metrics into separate classes: `PLV`, `CCorr`, `ACorr`, `Coh`, `ImCoh`, `PLI`, `WPLI`, `EnvCorr`, `PowCorr` + - Extracted 9 connectivity metrics into separate classes: `PLV`, `CCorr`, `ACCorr`, `Coh`, `ImCoh`, `PLI`, `WPLI`, `EnvCorr`, `PowCorr` - `BaseMetric` abstract class for uniform interface across all metrics - - `get_metric(mode, backend)` function for easy metric instantiation - - Backend support infrastructure (numpy default, with future support for numba/torch) + - `get_metric(mode, optimization)` function for easy metric instantiation - Helper functions: `multiply_conjugate`, `multiply_conjugate_time`, `multiply_product` +- **GPU and numba backends for all 9 sync metrics**: + - numba JIT with `prange`: PLV, CCorr, Coh, ImCoh, PLI, wPLI, EnvCorr, PowCorr + - PyTorch (MPS/CUDA/CPU) via batched einsum: all 9 metrics + - Metal compute shaders (Apple Silicon): PLI, wPLI, ACCorr + - CUDA raw kernels via CuPy (NVIDIA GPUs): all 9 metrics +- Benchmark-driven `AUTO_PRIORITY` table for `optimization='auto'`, compiled from + Mac M4 Max (131 runs) and Narval A100 (111 runs) benchmarks +- `priority` parameter on `get_metric()` and `compute_sync()` for custom backend ordering +- `hypyp/sync/kernels/` submodule with Metal and CUDA dispatch infrastructure +- New optional dependencies: `pyobjc-framework-Metal` (Apple), `cupy-cuda12x` (NVIDIA) +- `multiply_conjugate_torch` and `multiply_conjugate_time_torch` GPU helpers ### Changed - **BREAKING**: `accorr` metric now returns raw connectivity values with shape `(n_epoch, n_freq, 2*n_ch, 2*n_ch)` like all other metrics. The `swapaxes` and `epochs_average` operations are now handled by `compute_sync()` instead of being applied inside the metric. 
@@ -18,7 +28,7 @@ - `_multiply_conjugate()` in analyses.py - use `hypyp.sync.multiply_conjugate` instead (will be removed in 1.0.0) - `_multiply_conjugate_time()` in analyses.py - use `hypyp.sync.multiply_conjugate_time` instead (will be removed in 1.0.0) - `_multiply_product()` in analyses.py - use `hypyp.sync.multiply_product` instead (will be removed in 1.0.0) -- `_accorr_hybrid()` in analyses.py - use `hypyp.sync.ACorr` instead (will be removed in 1.0.0) +- `_accorr_hybrid()` in analyses.py - use `hypyp.sync.ACCorr` instead (will be removed in 1.0.0) ## [0.5.0b13] - 2025-09-18 diff --git a/hypyp/analyses.py b/hypyp/analyses.py index a4b8201..d13e058 100644 --- a/hypyp/analyses.py +++ b/hypyp/analyses.py @@ -439,7 +439,8 @@ def pair_connectivity(data: Union[list, np.ndarray], sampling_rate: int, def compute_sync(complex_signal: np.ndarray, mode: str, epochs_average: bool = True, - optimization: Optional[str] = None) -> np.ndarray: + optimization: Optional[str] = None, + priority: Optional[list] = None) -> np.ndarray: """ Computes frequency-domain connectivity measures from analytic signals. @@ -547,7 +548,7 @@ def compute_sync(complex_signal: np.ndarray, mode: str, epochs_average: bool = T # Get the metric from the sync module try: - metric = get_metric(mode_normalized, optimization=optimization) + metric = get_metric(mode_normalized, optimization=optimization, priority=priority) con = metric.compute(complex_signal, n_samp, transpose_axes) except ValueError: raise ValueError(f'Metric type "{mode}" not supported.') diff --git a/hypyp/sync/README.md b/hypyp/sync/README.md index 5fec773..5dd6b4b 100644 --- a/hypyp/sync/README.md +++ b/hypyp/sync/README.md @@ -58,9 +58,6 @@ Arbitrary methodological decisions skew inter-brain synchronization estimates in hyperscanning-EEG studies. *Imaging Neuroscience*, 2. https://doi.org/10.1162/imag_a_00350 -**Note:** ACCorr supports hardware acceleration via `optimization` parameter. 
-See [Optimization Backends](#optimization-backends) below. - --- ### Coherence (`coh`) @@ -165,24 +162,110 @@ amplitude. More sensitive to high-amplitude bursts. ## Optimization Backends -ACCorr supports three computational backends via the `optimization` parameter -in `compute_sync()` or the class constructor: +All 9 metrics support multiple computational backends via the `optimization` +parameter in `compute_sync()` or the class constructor. + +### Backend Support Matrix + +| Metric | numpy | numba | torch | metal | cuda_kernel | +|--------|:-----:|:-----:|:-----:|:-----:|:-----------:| +| PLV | x | x | x | -- | x | +| CCorr | x | x | x | -- | x | +| Coh | x | x | x | -- | x | +| ImCoh | x | x | x | -- | x | +| EnvCorr| x | x | x | -- | x | +| PowCorr| x | x | x | -- | x | +| PLI | x | x | x | x | x | +| wPLI | x | x | x | x | x | +| ACCorr | x | x | x | x | x | + +### Backend Descriptions | Value | Backend | Device | Notes | |-------|---------|--------|-------| | `None` (default) | NumPy | CPU | Standard, no extra dependencies | -| `'auto'` | Best available | Auto | torch → numba → numpy | -| `'numba'` | Numba JIT | CPU | ~2× speedup; install: `poetry install --with optim_numba` | -| `'torch'` | PyTorch | GPU/CPU | ~20× speedup on GPU; install: `poetry install --with optim_torch` | +| `'auto'` | Best available | Auto | Selects best GPU backend per metric and platform | +| `'numba'` | Numba JIT | CPU | Fused single-pass kernels with `prange` parallelism | +| `'torch'` | PyTorch | GPU/CPU | Batched einsum; MPS (Apple) / CUDA (NVIDIA) / CPU | +| `'metal'` | Metal shaders | Apple GPU | Custom compute shaders for PLI, wPLI, ACCorr only | +| `'cuda_kernel'` | CuPy RawKernel | NVIDIA GPU | Custom CUDA kernels; float64 precision | + +### `optimization='auto'` — Benchmark-Driven Dispatch + +The `'auto'` mode selects the best GPU backend for each metric based on +benchmark data compiled from Mac M4 Max (131 runs) and Narval A100 (111 runs). 
+ +**MPS (Apple Silicon):** +- Einsum metrics (PLV, CCorr, Coh, ImCoh, EnvCorr, PowCorr): torch (batched BLAS) +- Sign-based (PLI, wPLI) + ACCorr: Metal custom kernels + +**CUDA (NVIDIA):** +- All metrics: `cuda_kernel` first (pairwise computation, OOM-safe at 512+ channels), + with torch as fallback. + +The priority can be overridden per-call: +```python +get_metric('plv', optimization='auto', priority=['torch', 'cuda_kernel']) +``` + +If no GPU backend is available, `'auto'` falls back to numba, then numpy. + +### Precision + +- **CPU / CUDA (`float64`):** reference precision, `rtol=1e-9, atol=1e-10` +- **MPS / Metal (`float32`):** up to ~1e-5 difference vs CPU reference. + Sign-based metrics (PLI, wPLI) may show larger differences (`rtol=1e-2`) + near the sign discontinuity at zero. + +--- + +## Architecture -**Device priority for `'torch'` and `'auto'`:** MPS (Apple Silicon) > CUDA (NVIDIA) > CPU. -MPS and CUDA are mutually exclusive; the best available device is selected automatically. +``` +hypyp/sync/ +├── __init__.py # Registry, get_metric(), exports +├── base.py # BaseMetric, AUTO_PRIORITY, helpers +├── plv.py ... wpli.py # One file per metric (9 files) +└── kernels/ # Custom GPU kernels + ├── __init__.py # METAL_AVAILABLE, CUPY_AVAILABLE flags + ├── _metal_dispatch.py # Shared Metal pairwise dispatch + ├── _cuda_dispatch.py # Shared CUDA pairwise dispatch + ├── metal_phase.py # PLI, wPLI Metal shaders + ├── metal_accorr.py # ACCorr Metal shader + ├── cuda_phase.py # PLI, wPLI, PLV, CCorr CUDA kernels + ├── cuda_amplitude.py # Coh, ImCoh, EnvCorr, PowCorr CUDA kernels + └── cuda_accorr.py # ACCorr CUDA kernel +``` -**Precision note:** MPS uses `float32`, which may introduce numerical differences -of up to ~1e-5 compared to CPU/CUDA (`float64`). 
+Each metric class inherits from `BaseMetric` and implements: +- `_compute_numpy()` — always available (reference implementation) +- `_compute_numba()` — fused loop with `numba.prange` parallelism +- `_compute_torch()` — batched einsum on auto-detected device +- `_compute_metal()` — Metal shader dispatch (PLI, wPLI, ACCorr only) +- `_compute_cuda()` — CUDA RawKernel dispatch -All other metrics currently use numpy only (`optimization` parameter is accepted -but ignored for non-ACCorr metrics). +Backend selection happens at `__init__()`, dispatch at `compute()`. + +--- + +## Installation + +```bash +# Core (numpy backend always available) +pip install hypyp + +# CPU parallelism +pip install "hypyp[numba]" + +# GPU acceleration (PyTorch) +pip install "hypyp[torch]" + +# Apple Silicon Metal shaders (PLI, wPLI, ACCorr) +pip install "hypyp[metal]" + +# NVIDIA CUDA kernels (all metrics, requires CUDA 12.x) +pip install "hypyp[cupy]" +``` --- @@ -192,13 +275,20 @@ but ignored for non-ACCorr metrics). 
from hypyp.analyses import compute_sync # Standard (numpy) -con = compute_sync(complex_signal, 'accorr') +con = compute_sync(complex_signal, 'plv') + +# Best available GPU backend +con = compute_sync(complex_signal, 'plv', optimization='auto') + +# Specific backend +con = compute_sync(complex_signal, 'pli', optimization='metal') -# With GPU acceleration -con = compute_sync(complex_signal, 'accorr', optimization='torch') +# Custom priority +con = compute_sync(complex_signal, 'coh', optimization='auto', + priority=['torch', 'cuda_kernel']) # Direct class instantiation -from hypyp.sync import ACCorr -metric = ACCorr(optimization='auto', show_progress=True) +from hypyp.sync import get_metric +metric = get_metric('accorr', optimization='auto') con = metric.compute(complex_signal_internal, n_samp, transpose_axes) ``` diff --git a/hypyp/sync/__init__.py b/hypyp/sync/__init__.py index 5c57db2..dfca90b 100644 --- a/hypyp/sync/__init__.py +++ b/hypyp/sync/__init__.py @@ -10,7 +10,10 @@ from typing import Optional -from .base import BaseMetric, multiply_conjugate, multiply_conjugate_time, multiply_product +from .base import ( + BaseMetric, multiply_conjugate, multiply_conjugate_time, multiply_product, + multiply_conjugate_torch, multiply_conjugate_time_torch, +) from .plv import PLV from .ccorr import CCorr from .accorr import ACCorr @@ -40,6 +43,8 @@ 'multiply_conjugate', 'multiply_conjugate_time', 'multiply_product', + 'multiply_conjugate_torch', + 'multiply_conjugate_time_torch', # Metric classes 'PLV', 'CCorr', @@ -56,7 +61,8 @@ ] -def get_metric(mode: str, optimization: Optional[str] = None) -> BaseMetric: +def get_metric(mode: str, optimization: Optional[str] = None, + priority: Optional[list] = None) -> BaseMetric: """ Get a connectivity metric instance by name. @@ -66,8 +72,11 @@ def get_metric(mode: str, optimization: Optional[str] = None) -> BaseMetric: Name of the connectivity metric. 
One of: 'plv', 'ccorr', 'accorr', 'coh', 'imcoh', 'pli', 'wpli', 'envcorr', 'powcorr'. optimization : str, optional - Optimization strategy. Options: None, 'auto', 'numba', 'torch'. - See BaseMetric for fallback behavior. + Optimization strategy. Options: None, 'auto', 'numba', 'torch', + 'metal', 'cuda_kernel'. See BaseMetric for fallback behavior. + priority : list of str, optional + Custom backend priority for ``'auto'`` mode. Overrides the default + ``AUTO_PRIORITY`` table. Example: ``['metal', 'torch', 'numba']``. Returns ------- @@ -82,12 +91,13 @@ def get_metric(mode: str, optimization: Optional[str] = None) -> BaseMetric: Examples -------- >>> from hypyp.sync import get_metric - >>> accorr = get_metric('accorr', optimization='torch') - >>> result = accorr.compute(complex_signal, n_samp, transpose_axes) + >>> plv = get_metric('plv', optimization='auto') # benchmark-driven + >>> pli = get_metric('pli', optimization='auto', + ... priority=['numba', 'metal']) # custom priority """ mode_lower = mode.lower() if mode_lower not in METRICS: available = ', '.join(METRICS.keys()) raise ValueError(f"Unknown metric mode '{mode}'. Available: {available}") - return METRICS[mode_lower](optimization=optimization) + return METRICS[mode_lower](optimization=optimization, priority=priority) diff --git a/hypyp/sync/accorr.py b/hypyp/sync/accorr.py index fb0506f..08a3831 100644 --- a/hypyp/sync/accorr.py +++ b/hypyp/sync/accorr.py @@ -66,8 +66,9 @@ class ACCorr(BaseMetric): name = "accorr" def __init__(self, optimization: Optional[str] = None, + priority: Optional[list] = None, show_progress: bool = True): - super().__init__(optimization) + super().__init__(optimization, priority) self.show_progress = show_progress def compute(self, complex_signal: np.ndarray, n_samp: int, @@ -89,13 +90,29 @@ def compute(self, complex_signal: np.ndarray, n_samp: int, con : np.ndarray ACCorr connectivity matrix with shape (n_epoch, n_freq, 2*n_ch, 2*n_ch). 
""" - if self._backend == 'numba': + if self._backend == 'metal': + return self._compute_metal(complex_signal, n_samp, transpose_axes) + elif self._backend == 'cuda_kernel': + return self._compute_cuda(complex_signal, n_samp, transpose_axes) + elif self._backend == 'numba': return self._compute_numba(complex_signal, n_samp, transpose_axes) elif self._backend == 'torch': return self._compute_torch(complex_signal, n_samp, transpose_axes) else: return self._compute_numpy(complex_signal, n_samp, transpose_axes) + def _compute_metal(self, complex_signal: np.ndarray, n_samp: int, + transpose_axes: tuple) -> np.ndarray: + """Metal compute shader for ACCorr on Apple Silicon GPU.""" + from .kernels.metal_accorr import accorr_metal + return accorr_metal(complex_signal) + + def _compute_cuda(self, complex_signal: np.ndarray, n_samp: int, + transpose_axes: tuple) -> np.ndarray: + """CUDA kernel for ACCorr on NVIDIA GPU.""" + from .kernels.cuda_accorr import accorr_cuda + return accorr_cuda(complex_signal) + def _compute_numpy(self, complex_signal: np.ndarray, n_samp: int, transpose_axes: tuple) -> np.ndarray: """ @@ -161,8 +178,8 @@ def _compute_numba(self, complex_signal: np.ndarray, n_samp: int, """ Numba-optimized implementation of ACCorr with precompute. - Uses numba JIT compilation for the denominator loop. - Note: parallelization is currently disabled due to a dependency conflict. + Uses numba JIT compilation with prange parallelization for the + denominator loop. """ n_epochs, n_freq, n_ch_total, n_times = complex_signal.shape @@ -194,6 +211,10 @@ def _compute_numba(self, complex_signal: np.ndarray, n_samp: int, return con + # Memory threshold for vectorized denominator (bytes). If the 5D tensor + # (E, F, C, C, T) would exceed this, fall back to the loop-based approach. 
+ _VRAM_THRESHOLD = 2 * 1024**3 # 2 GB + def _compute_torch(self, complex_signal: np.ndarray, n_samp: int, transpose_axes: tuple) -> np.ndarray: """ @@ -201,15 +222,21 @@ def _compute_torch(self, complex_signal: np.ndarray, n_samp: int, Uses torch tensor operations on the resolved device (cpu/mps/cuda). MPS uses float32 precision; cpu/cuda uses float64. + + The denominator is fully vectorized via broadcasting when the + intermediate 5D tensor fits in memory (< _VRAM_THRESHOLD). Otherwise, + falls back to a per-pair loop on device. """ device = self._device if device == 'mps': float_type = torch.float32 complex_type = torch.complex64 + bytes_per_elem = 4 else: float_type = torch.float64 complex_type = torch.complex128 + bytes_per_elem = 8 complex_tensor = torch.from_numpy(complex_signal).to(device=device, dtype=complex_type) n_epochs, n_freq, n_ch_total, n_times = complex_tensor.shape @@ -218,11 +245,14 @@ def _compute_torch(self, complex_signal: np.ndarray, n_samp: int, z = complex_tensor / torch.abs(complex_tensor) c, s = z.real, z.imag + # Factorized: 4 einsum shared between cross_conj and cross_prod formula = 'efit,efjt->efij' - cross_conj = (torch.einsum(formula, c, c) + torch.einsum(formula, s, s)) - 1j * \ - (torch.einsum(formula, c, s) - torch.einsum(formula, s, c)) - cross_prod = (torch.einsum(formula, c, c) - torch.einsum(formula, s, s)) + 1j * \ - (torch.einsum(formula, c, s) + torch.einsum(formula, s, c)) + cc = torch.einsum(formula, c, c) + ss = torch.einsum(formula, s, s) + cs = torch.einsum(formula, c, s) + sc = torch.einsum(formula, s, c) + cross_conj = (cc + ss) - 1j * (cs - sc) + cross_prod = (cc - ss) + 1j * (cs + sc) r_minus = torch.abs(cross_conj) r_plus = torch.abs(cross_prod) @@ -235,8 +265,51 @@ def _compute_torch(self, complex_signal: np.ndarray, n_samp: int, n_adj_all = -0.5 * (mean_diff_all - mean_sum_all) m_adj_all = mean_diff_all + n_adj_all - # Denominator - loop on device + # Denominator — choose vectorized or loop based on memory 
angle = torch.angle(complex_tensor) + tensor_5d_bytes = n_epochs * n_freq * n_ch_total * n_ch_total * n_times * bytes_per_elem + use_vectorized = tensor_5d_bytes < self._VRAM_THRESHOLD + + if use_vectorized: + den = self._den_vectorized(angle, m_adj_all, n_adj_all, device, float_type) + else: + den = self._den_loop(angle, m_adj_all, n_adj_all, device, float_type, + n_epochs, n_freq, n_ch_total) + + den = torch.where(den == 0, torch.ones_like(den), den) + con = num / den + + return con.cpu().numpy() + + def _den_vectorized(self, angle, m_adj_all, n_adj_all, device, float_type): + """ + Fully vectorized denominator via broadcasting. + + Broadcasts angle (E,F,C,T) against m_adj_all (E,F,C,C) to compute + sin(angle_i - m_adj_{ij}) for all pairs simultaneously. + + Shape flow: + angle[:,:,:,None,:] - m_adj_all[:,:,:,:,None] -> (E, F, C, C, T) + sin -> square -> sum over T -> sqrt -> 2 * sqrt(prod) + """ + # angle: (E, F, C, T) -> (E, F, C, 1, T) + # m_adj_all: (E, F, C, C) -> (E, F, C, C, 1) + x_sin = torch.sin(angle.unsqueeze(3) - m_adj_all.unsqueeze(-1)) # (E,F,C,C,T) + y_sin = torch.sin(angle.unsqueeze(2) - n_adj_all.unsqueeze(-1)) # (E,F,C,C,T) + + sum_x2 = torch.sum(x_sin ** 2, dim=-1) # (E, F, C, C) + sum_y2 = torch.sum(y_sin ** 2, dim=-1) # (E, F, C, C) + + return 2.0 * torch.sqrt(sum_x2 * sum_y2) + + def _den_loop(self, angle, m_adj_all, n_adj_all, device, float_type, + n_epochs, n_freq, n_ch_total): + """ + Loop-based denominator (fallback for large data). + + Iterates over channel pairs when the vectorized 5D tensor + would exceed the VRAM threshold. 
+ """ den = torch.zeros((n_epochs, n_freq, n_ch_total, n_ch_total), device=device, dtype=float_type) @@ -262,45 +335,37 @@ def _compute_torch(self, complex_signal: np.ndarray, n_samp: int, pbar.update(1) pbar.close() - - den = torch.where(den == 0, torch.ones_like(den), den) - con = num / den - - return con.cpu().numpy() + return den # Numba JIT-compiled helper (defined at module level for caching) if NUMBA_AVAILABLE: - # TODO(@m2march): research why parallelization is not working - @njit(parallel=False, cache=True) + @njit(parallel=True, cache=True) def _accorr_den_numba(n_epochs, n_freq, n_ch_total, angle, m_adj_all, n_adj_all): - """Numba JIT-compiled denominator calculation for accorr.""" - den = np.zeros((n_epochs, n_freq, n_ch_total, n_ch_total)) - - for i in range(den.shape[2]): - for j in range(i, den.shape[3]): - alpha1 = angle[:, :, i, :] - alpha2 = angle[:, :, j, :] - - m_adj = m_adj_all[:, :, i, j] - n_adj = n_adj_all[:, :, i, j] - - x = alpha1.copy() - for xi in range(x.shape[0]): - for xj in range(x.shape[1]): - for xk in range(x.shape[2]): - x[xi, xj, xk] -= m_adj[xi, xj] - x_sin = np.sin(x) + """ + Numba JIT-compiled denominator calculation for accorr. - y = alpha2.copy() - for yi in range(y.shape[0]): - for yj in range(y.shape[1]): - for yk in range(y.shape[2]): - y[yi, yj, yk] -= n_adj[yi, yj] - y_sin = np.sin(y) + Uses prange for parallel iteration over channel pairs. The inner + subtraction uses explicit loops (numba-compatible) instead of + .copy() + broadcasting which caused allocation issues with prange. 
+ """ + den = np.zeros((n_epochs, n_freq, n_ch_total, n_ch_total)) - den_ij = 2 * np.sqrt(np.sum(x_sin**2, axis=2) * np.sum(y_sin**2, axis=2)) - den[:, :, i, j] = den_ij - den[:, :, j, i] = den_ij + for i in prange(n_ch_total): + for j in range(i, n_ch_total): + # Compute sum of sin^2 for x and y directly, no temp arrays + for ei in range(n_epochs): + for fi in range(n_freq): + m = m_adj_all[ei, fi, i, j] + n = n_adj_all[ei, fi, i, j] + sum_x2 = 0.0 + sum_y2 = 0.0 + for ti in range(angle.shape[3]): + sx = np.sin(angle[ei, fi, i, ti] - m) + sy = np.sin(angle[ei, fi, j, ti] - n) + sum_x2 += sx * sx + sum_y2 += sy * sy + den[ei, fi, i, j] = 2.0 * np.sqrt(sum_x2 * sum_y2) + den[ei, fi, j, i] = den[ei, fi, i, j] return den diff --git a/hypyp/sync/base.py b/hypyp/sync/base.py index 5171385..160b064 100644 --- a/hypyp/sync/base.py +++ b/hypyp/sync/base.py @@ -35,6 +35,48 @@ except ImportError: NUMBA_AVAILABLE = False +# Custom kernel backends +from .kernels import METAL_AVAILABLE, CUPY_AVAILABLE + + +# --------------------------------------------------------------------------- +# Benchmark-driven GPU backend priority for optimization='auto' +# --------------------------------------------------------------------------- +# Compiled from Mac M4 Max (131 rows) and Narval A100 (111 rows) benchmarks. +# Format: {metric_name: {platform: [gpu_backend_1, gpu_backend_2]}} +# First available GPU backend in the list wins. +# +# 'auto' selects the best GPU backend only. Users choose CPU strategies +# explicitly: optimization=None (numpy) or optimization='numba'. +# +# The priority can be overridden per-call via the `priority` parameter: +# get_metric('plv', optimization='auto', priority=['metal', 'torch']) +# +# Rationale: +# MPS — einsum metrics: torch wins (batched matrix ops via Apple MPS). +# sign-based/accorr: Metal custom kernels win (sign() and circular +# correlation are not vectorizable; torch OOMs at ≥512ch for PLI/wPLI). 
+# No Metal kernels for einsum metrics (torch_mps dominates at all scales). +# CUDA — cuda_kernel first for all metrics: torch_cuda is faster at small/ +# medium scale but OOMs at realistic_hd (512ch) due to large +# intermediate tensors. cuda_kernel computes pairwise without +# materializing the full output tensor. +AUTO_PRIORITY = { + # einsum metrics — torch wins on MPS, cuda_kernel safe-first on CUDA + # (torch OOMs at ≥512ch on CUDA; cuda_kernel computes pairwise) + 'plv': {'mps': ['torch'], 'cuda': ['cuda_kernel', 'torch']}, + 'ccorr': {'mps': ['torch'], 'cuda': ['cuda_kernel', 'torch']}, + 'coh': {'mps': ['torch'], 'cuda': ['cuda_kernel', 'torch']}, + 'imcoh': {'mps': ['torch'], 'cuda': ['cuda_kernel', 'torch']}, + 'envcorr': {'mps': ['torch'], 'cuda': ['cuda_kernel', 'torch']}, + 'powcorr': {'mps': ['torch'], 'cuda': ['cuda_kernel', 'torch']}, + # sign-based — custom kernels beat torch on both platforms + 'pli': {'mps': ['metal', 'torch'], 'cuda': ['cuda_kernel', 'torch']}, + 'wpli': {'mps': ['metal', 'torch'], 'cuda': ['cuda_kernel', 'torch']}, + # accorr — Metal wins on MPS (circular correlation), cuda_kernel safe on CUDA + 'accorr': {'mps': ['metal', 'torch'], 'cuda': ['cuda_kernel', 'torch']}, +} + def multiply_conjugate(real: np.ndarray, imag: np.ndarray, transpose_axes: tuple) -> np.ndarray: """ @@ -127,6 +169,56 @@ def multiply_product(real: np.ndarray, imag: np.ndarray, transpose_axes: tuple) return product +def multiply_conjugate_torch(c, s): + """ + Compute z * conj(z) using torch tensors, collapsing time dimension. + + Torch equivalent of :func:`multiply_conjugate`. Uses the einsum convention + ``e=epoch, f=freq, i=ch_row, j=ch_col, t=time``. + + Parameters + ---------- + c : torch.Tensor + Real part, shape (E, F, C, T). + s : torch.Tensor + Imaginary part, shape (E, F, C, T). + + Returns + ------- + torch.Tensor + Complex product, shape (E, F, C, C). 
+ """ + formula = 'efit,efjt->efij' + import torch + return (torch.einsum(formula, c, c) + torch.einsum(formula, s, s)) - 1j * \ + (torch.einsum(formula, c, s) - torch.einsum(formula, s, c)) + + +def multiply_conjugate_time_torch(c, s): + """ + Compute z * conj(z) using torch tensors, preserving time dimension. + + Torch equivalent of :func:`multiply_conjugate_time`. Produces a 5D tensor + ``(E, F, C, C, T)`` — can be very large for high channel counts. + + Parameters + ---------- + c : torch.Tensor + Real part, shape (E, F, C, T). + s : torch.Tensor + Imaginary part, shape (E, F, C, T). + + Returns + ------- + torch.Tensor + Complex product, shape (E, F, C, C, T). + """ + formula = 'efit,efjt->efijt' + import torch + return (torch.einsum(formula, c, c) + torch.einsum(formula, s, s)) - 1j * \ + (torch.einsum(formula, c, s) - torch.einsum(formula, s, c)) + + class BaseMetric(ABC): """ Abstract base class for connectivity metrics. @@ -153,12 +245,17 @@ class BaseMetric(ABC): name: str = "base" - def __init__(self, optimization: Optional[str] = None): + def __init__(self, optimization: Optional[str] = None, + priority: Optional[list] = None): self.optimization = optimization - self._backend, self._device = self._resolve_optimization(optimization) + self._priority = priority + self._backend, self._device = self._resolve_optimization( + optimization, priority + ) - @staticmethod - def _resolve_optimization(optimization: Optional[str] = None) -> tuple: + @classmethod + def _resolve_optimization(cls, optimization: Optional[str] = None, + priority: Optional[list] = None) -> tuple: """ Resolves an optimization value to (backend, device). @@ -171,38 +268,46 @@ def _resolve_optimization(optimization: Optional[str] = None) -> tuple: Requested optimization strategy: - ``None``: standard numpy, no acceleration (default). - - ``'auto'``: best available backend — tries torch first, then - numba, then falls back to numpy. No warning is emitted. 
+ - ``'auto'``: best available backend, selected per-metric from + the ``AUTO_PRIORITY`` table (compiled from benchmarks). + See ``_resolve_auto`` for details. - ``'numba'``: JIT-compiled loops via numba. Falls back to numpy with a UserWarning if numba is not installed. - ``'torch'``: PyTorch tensors with auto-detected GPU (see ``_resolve_torch`` for device priority). Falls back to numpy with a UserWarning if torch is not installed. + - ``'metal'``: Apple Metal compute shaders. Falls back to numpy + with a UserWarning if PyObjC Metal is not available. + - ``'cuda_kernel'``: Custom CUDA kernels via CuPy. Falls back + to numpy with a UserWarning if CuPy is not available. + priority : list of str, optional + Custom backend priority list for ``'auto'`` mode. Overrides + the default ``AUTO_PRIORITY`` table for this call. + Example: ``['metal', 'torch', 'numba']``. Returns ------- backend : str - One of ``'numpy'``, ``'numba'``, ``'torch'``. + One of ``'numpy'``, ``'numba'``, ``'torch'``, ``'metal'``, + ``'cuda_kernel'``. device : str One of ``'cpu'``, ``'mps'``, ``'cuda'``. Notes ----- - Fallback cascade for ``'auto'``: - torch (best available device) → numba → numpy + Fallback cascade for ``'auto'`` (per-metric, per-platform): + Iterates ``AUTO_PRIORITY[metric][platform]`` and returns the + first available backend. Falls back to numba → numpy if no + GPU backend is available. 
- Fallback cascade for ``'torch'`` or ``'numba'`` when unavailable: + Fallback cascade for explicit backends when unavailable: requested backend → numpy (with UserWarning) """ if optimization is None: return 'numpy', 'cpu' if optimization == 'auto': - if TORCH_AVAILABLE: - return BaseMetric._resolve_torch() - if NUMBA_AVAILABLE: - return 'numba', 'cpu' - return 'numpy', 'cpu' + return cls._resolve_auto(priority) if optimization == 'numba': if NUMBA_AVAILABLE: @@ -216,7 +321,7 @@ def _resolve_optimization(optimization: Optional[str] = None) -> tuple: if optimization == 'torch': if TORCH_AVAILABLE: - return BaseMetric._resolve_torch() + return cls._resolve_torch() warnings.warn( "torch not installed, falling back to numpy. " "Install with: poetry install --with optim_torch", @@ -224,10 +329,93 @@ def _resolve_optimization(optimization: Optional[str] = None) -> tuple: ) return 'numpy', 'cpu' + if optimization == 'metal': + if METAL_AVAILABLE: + return 'metal', 'mps' + warnings.warn( + "PyObjC Metal not available, falling back to numpy. " + "Install with: pip install pyobjc-framework-Metal", + UserWarning, stacklevel=3 + ) + return 'numpy', 'cpu' + + if optimization == 'cuda_kernel': + if CUPY_AVAILABLE: + return 'cuda_kernel', 'cuda' + warnings.warn( + "CuPy not available, falling back to numpy. " + "Install with: pip install cupy-cuda12x", + UserWarning, stacklevel=3 + ) + return 'numpy', 'cpu' + raise ValueError( f"Unknown optimization '{optimization}'. " - f"Options: None, 'auto', 'numba', 'torch'" + f"Options: None, 'auto', 'numba', 'torch', 'metal', 'cuda_kernel'" + ) + + @classmethod + def _resolve_auto(cls, priority: Optional[list] = None) -> tuple: + """ + Benchmark-driven backend selection, per metric and platform. + + Uses the ``AUTO_PRIORITY`` table compiled from Mac M4 Max and + Narval A100 benchmarks. Iterates the priority list and returns + the first available backend. 
+ + Parameters + ---------- + priority : list of str, optional + Custom priority list overriding ``AUTO_PRIORITY`` for this call. + + Returns + ------- + backend : str + Selected backend name. + device : str + Associated device (``'cpu'``, ``'mps'``, or ``'cuda'``). + + Notes + ----- + Platform detection: MPS → 'mps', CUDA → 'cuda', else 'cpu'. + On CPU-only machines, warns and falls back to numba → numpy. + """ + if MPS_AVAILABLE: + platform = 'mps' + elif CUDA_AVAILABLE: + platform = 'cuda' + else: + # No GPU — warn and fall back to CPU + warnings.warn( + "No GPU available. optimization='auto' selects the best GPU " + "backend. Use optimization='numba' for CPU parallelism or " + "optimization=None for numpy.", + UserWarning, stacklevel=4 + ) + if NUMBA_AVAILABLE: + return 'numba', 'cpu' + return 'numpy', 'cpu' + + if priority is None: + priority = AUTO_PRIORITY.get(cls.name, {}).get(platform, []) + + for backend in priority: + if backend == 'torch' and TORCH_AVAILABLE: + return cls._resolve_torch() + if backend == 'metal' and METAL_AVAILABLE: + return 'metal', 'mps' + if backend == 'cuda_kernel' and CUPY_AVAILABLE: + return 'cuda_kernel', 'cuda' + + # No GPU backend from priority list available — fall back + warnings.warn( + f"No GPU backend available for {cls.name!r} on platform " + f"'{platform}'. 
Falling back to CPU.", + UserWarning, stacklevel=4 ) + if NUMBA_AVAILABLE: + return 'numba', 'cpu' + return 'numpy', 'cpu' @staticmethod def _resolve_torch() -> tuple: diff --git a/hypyp/sync/ccorr.py b/hypyp/sync/ccorr.py index 8806545..78e02b8 100644 --- a/hypyp/sync/ccorr.py +++ b/hypyp/sync/ccorr.py @@ -6,9 +6,14 @@ """ import numpy as np -from scipy.stats import circmean -from .base import BaseMetric +from .base import BaseMetric, TORCH_AVAILABLE, NUMBA_AVAILABLE + +if TORCH_AVAILABLE: + import torch + +if NUMBA_AVAILABLE: + from numba import njit, prange class CCorr(BaseMetric): @@ -46,23 +51,183 @@ def compute(self, complex_signal: np.ndarray, n_samp: int, con : np.ndarray CCorr connectivity matrix with shape (n_epoch, n_freq, 2*n_ch, 2*n_ch). """ + if self._backend == 'cuda_kernel': + return self._compute_cuda(complex_signal, n_samp, transpose_axes) + elif self._backend == 'torch': + return self._compute_torch(complex_signal, n_samp, transpose_axes) + elif self._backend == 'numba': + return self._compute_numba(complex_signal, n_samp, transpose_axes) return self._compute_numpy(complex_signal, n_samp, transpose_axes) + + def _compute_cuda(self, complex_signal, n_samp, transpose_axes): + """CUDA kernel for CCorr.""" + from .kernels.cuda_phase import ccorr_cuda + return ccorr_cuda(complex_signal) def _compute_numpy(self, complex_signal: np.ndarray, n_samp: int, transpose_axes: tuple) -> np.ndarray: - """NumPy implementation of CCorr.""" + """ + NumPy implementation of CCorr. + + Uses inline circular mean (atan2(mean(sin), mean(cos))) instead of + scipy.stats.circmean to remove the scipy dependency from this module. + Mathematically identical: circmean(x, high=pi, low=-pi) == atan2(mean(sin(x)), mean(cos(x))). 
+ """ n_epoch = complex_signal.shape[0] n_freq = complex_signal.shape[1] n_ch_total = complex_signal.shape[2] - + angle = np.angle(complex_signal) - mu_angle = circmean(angle, high=np.pi, low=-np.pi, axis=3).reshape( - n_epoch, n_freq, n_ch_total, 1 - ) + # Circular mean: atan2(mean(sin(angle)), mean(cos(angle))) + # Equivalent to scipy.stats.circmean(angle, high=pi, low=-pi, axis=3) + mu_angle = np.arctan2( + np.mean(np.sin(angle), axis=3), + np.mean(np.cos(angle), axis=3), + ).reshape(n_epoch, n_freq, n_ch_total, 1) angle = np.sin(angle - mu_angle) formula = 'nilm,nimk->nilk' con = np.abs(np.einsum(formula, angle, angle.transpose(transpose_axes)) / - np.sqrt(np.einsum('nil,nik->nilk', np.sum(angle ** 2, axis=3), + np.sqrt(np.einsum('nil,nik->nilk', np.sum(angle ** 2, axis=3), np.sum(angle ** 2, axis=3)))) return con + + def _compute_numba(self, complex_signal: np.ndarray, n_samp: int, + transpose_axes: tuple) -> np.ndarray: + """ + Numba JIT implementation of CCorr using angle-free reformulation. + + Works with cos(φ)/sin(φ) directly, avoiding all transcendental + functions inside the JIT-compiled loops. Uses prange for + parallelization across epochs and symmetry exploitation (upper triangle). + """ + phase = complex_signal / np.abs(complex_signal) + c = np.real(phase) + s = np.imag(phase) + return _ccorr_numba_kernel(c, s) + + def _compute_torch(self, complex_signal: np.ndarray, n_samp: int, + transpose_axes: tuple) -> np.ndarray: + """ + PyTorch implementation of CCorr using angle-free reformulation. + + Instead of extracting phase angles (torch.angle + arctan2 + sin), + works directly with cos(φ) and sin(φ) from the unit-phase signal. + + The circular centering sin(φ - μ) is reformulated as: + d(t) = sin(φ(t)) * C̄ - cos(φ(t)) * S̄ + where C̄ = mean(cos(φ)), S̄ = mean(sin(φ)). + The normalization factor R = √(C̄² + S̄²) cancels in the correlation. 
+ + This eliminates all transcendental functions (angle, arctan2, sin) + after the initial phase normalization, improving MPS float32 precision. + """ + device = self._device + complex_type = torch.complex64 if device == 'mps' else torch.complex128 + + sig = torch.from_numpy(complex_signal).to(device=device, dtype=complex_type) + + # Unit-phase signal: same as PLV + phase = sig / torch.abs(sig) + c, s = phase.real, phase.imag # cos(φ), sin(φ) + + # Circular mean components (no atan2) + C_bar = torch.mean(c, dim=3, keepdim=True) # (E, F, C, 1) + S_bar = torch.mean(s, dim=3, keepdim=True) # (E, F, C, 1) + + # Centered signal: d(t) = s(t)*C_bar - c(t)*S_bar + # R factor cancels in correlation, no division needed + d = s * C_bar - c * S_bar # (E, F, C, T) + + # Correlation via einsum + formula = 'efit,efjt->efij' + num = torch.einsum(formula, d, d) + sum_sq = torch.sum(d ** 2, dim=3) + den = torch.sqrt(torch.einsum('efi,efj->efij', sum_sq, sum_sq)) + + con = torch.abs(num / den) + return con.cpu().numpy() + + def _compute_torch_cpu_circmean(self, complex_signal: np.ndarray, n_samp: int, + transpose_axes: tuple) -> np.ndarray: + """ + Hybrid approach: circular mean in float64 on CPU, correlation on GPU. + + Computes the precision-sensitive circular mean (arctan2) in float64 + on CPU, then transfers the centered signal to GPU for the einsum. + Kept for comparison with the angle-free reformulation. 
+ """ + device = self._device + float_type = torch.float32 if device == 'mps' else torch.float64 + + # Step 1: Circular mean in float64 on CPU (precision-critical) + angle = np.angle(complex_signal) + mu_angle = np.arctan2( + np.mean(np.sin(angle), axis=3), + np.mean(np.cos(angle), axis=3), + ).reshape(complex_signal.shape[0], complex_signal.shape[1], + complex_signal.shape[2], 1) + centered = np.sin(angle - mu_angle) # float64, precise + + # Step 2: Transfer centered signal to GPU for einsum + d = torch.from_numpy(centered).to(device=device, dtype=float_type) + + formula = 'efit,efjt->efij' + num = torch.einsum(formula, d, d) + sum_sq = torch.sum(d ** 2, dim=3) + den = torch.sqrt(torch.einsum('efi,efj->efij', sum_sq, sum_sq)) + + con = torch.abs(num / den) + return con.cpu().numpy() + + +# Numba JIT kernel (module-level for caching) +if NUMBA_AVAILABLE: + @njit(parallel=True, cache=True) + def _ccorr_numba_kernel(c, s): + """ + Angle-free CCorr kernel: no transcendental functions inside loops. + + Uses d_i(t) = s_i(t)*C_bar_i - c_i(t)*S_bar_i for circular centering. + The normalization factor R cancels in the correlation. + Exploits symmetry: CCorr(i,j) == CCorr(j,i). 
+ """ + n_ep, n_freq, n_ch, n_t = c.shape + con = np.zeros((n_ep, n_freq, n_ch, n_ch)) + + for e in prange(n_ep): + for f in range(n_freq): + # Pre-compute mean(cos) and mean(sin) per channel + C_bar = np.zeros(n_ch) + S_bar = np.zeros(n_ch) + for ch in range(n_ch): + c_sum = 0.0 + s_sum = 0.0 + for t in range(n_t): + c_sum += c[e, f, ch, t] + s_sum += s[e, f, ch, t] + C_bar[ch] = c_sum / n_t + S_bar[ch] = s_sum / n_t + + # Correlation for upper triangle + diagonal + for i in range(n_ch): + for j in range(i, n_ch): + num = 0.0 + den_i = 0.0 + den_j = 0.0 + for t in range(n_t): + # d(t) = s(t)*C_bar - c(t)*S_bar + di = s[e, f, i, t] * C_bar[i] - c[e, f, i, t] * S_bar[i] + dj = s[e, f, j, t] * C_bar[j] - c[e, f, j, t] * S_bar[j] + num += di * dj + den_i += di * di + den_j += dj * dj + denom = np.sqrt(den_i * den_j) + if denom > 0: + val = np.abs(num) / denom + else: + val = 0.0 + con[e, f, i, j] = val + con[e, f, j, i] = val # symmetry + + return con diff --git a/hypyp/sync/coh.py b/hypyp/sync/coh.py index e91cb97..516f610 100644 --- a/hypyp/sync/coh.py +++ b/hypyp/sync/coh.py @@ -7,7 +7,14 @@ import numpy as np -from .base import BaseMetric, multiply_conjugate +from .base import BaseMetric, multiply_conjugate, TORCH_AVAILABLE, NUMBA_AVAILABLE + +if TORCH_AVAILABLE: + import torch + from .base import multiply_conjugate_torch + +if NUMBA_AVAILABLE: + from numba import njit, prange class Coh(BaseMetric): @@ -49,8 +56,19 @@ def compute(self, complex_signal: np.ndarray, n_samp: int, con : np.ndarray Coherence connectivity matrix with shape (n_epoch, n_freq, 2*n_ch, 2*n_ch). 
""" + if self._backend == 'cuda_kernel': + return self._compute_cuda(complex_signal, n_samp, transpose_axes) + elif self._backend == 'torch': + return self._compute_torch(complex_signal, n_samp, transpose_axes) + elif self._backend == 'numba': + return self._compute_numba(complex_signal, n_samp, transpose_axes) return self._compute_numpy(complex_signal, n_samp, transpose_axes) - + + def _compute_cuda(self, complex_signal, n_samp, transpose_axes): + """CUDA kernel for Coherence.""" + from .kernels.cuda_amplitude import coh_cuda + return coh_cuda(complex_signal) + def _compute_numpy(self, complex_signal: np.ndarray, n_samp: int, transpose_axes: tuple) -> np.ndarray: """NumPy implementation of Coherence.""" @@ -61,3 +79,87 @@ def _compute_numpy(self, complex_signal: np.ndarray, n_samp: int, con = np.abs(dphi) / np.sqrt(np.einsum('nil,nik->nilk', np.nansum(amp, axis=3), np.nansum(amp, axis=3))) return con + + def _compute_numba(self, complex_signal: np.ndarray, n_samp: int, + transpose_axes: tuple) -> np.ndarray: + """ + Numba JIT implementation of Coherence with parallel epoch processing. + + Fuses cross-spectrum and power normalization into a single loop pass. + Accumulates numerator (cross-spectrum) and denominator (power) in + CPU registers — zero intermediate tensor allocations. + """ + c = np.real(complex_signal) + s = np.imag(complex_signal) + return _coh_numba_kernel(c, s) + + def _compute_torch(self, complex_signal: np.ndarray, n_samp: int, + transpose_axes: tuple) -> np.ndarray: + """ + PyTorch implementation of Coherence. + + Uses multiply_conjugate_torch for the cross-spectrum numerator + and torch.einsum for the power normalization denominator. + MPS uses float32; CPU/CUDA uses float64. 
+ """ + device = self._device + complex_type = torch.complex64 if device == 'mps' else torch.complex128 + + sig = torch.from_numpy(complex_signal).to(device=device, dtype=complex_type) + c, s = sig.real, sig.imag + + # Cross-spectrum: sum_t(X_i * conj(X_j)) — contracts time dim + dphi = multiply_conjugate_torch(c, s) + + # Power normalization: sqrt(sum|X_i|² * sum|X_j|²) + amp = torch.abs(sig) ** 2 + power = torch.nansum(amp, dim=3) + den = torch.sqrt(torch.einsum('efi,efj->efij', power, power)) + + con = torch.abs(dphi) / den + return con.cpu().numpy() + + +# Numba JIT kernel (module-level for caching) +if NUMBA_AVAILABLE: + @njit(parallel=True, cache=True) + def _coh_numba_kernel(c, s): + """ + Fused Coherence: cross-spectrum + power normalization in one pass. + + For each (epoch, freq, i, j): + cross = sum_t (c_i*c_j + s_i*s_j) + i*(s_i*c_j - c_i*s_j) + pow_i = sum_t (c_i² + s_i²) + pow_j = sum_t (c_j² + s_j²) + coh = |cross| / sqrt(pow_i * pow_j) + """ + n_ep, n_freq, n_ch, n_t = c.shape + con = np.zeros((n_ep, n_freq, n_ch, n_ch)) + + for e in prange(n_ep): + for f in range(n_freq): + # Pre-compute power per channel + power = np.zeros(n_ch) + for ch in range(n_ch): + p = 0.0 + for t in range(n_t): + p += c[e, f, ch, t] ** 2 + s[e, f, ch, t] ** 2 + power[ch] = p + + # Cross-spectrum for all pairs + for i in range(n_ch): + for j in range(i, n_ch): + re_sum = 0.0 + im_sum = 0.0 + for t in range(n_t): + re_sum += c[e, f, i, t] * c[e, f, j, t] + s[e, f, i, t] * s[e, f, j, t] + im_sum += s[e, f, i, t] * c[e, f, j, t] - c[e, f, i, t] * s[e, f, j, t] + denom = np.sqrt(power[i] * power[j]) + if denom > 0: + val = np.sqrt(re_sum ** 2 + im_sum ** 2) / denom + else: + val = 0.0 + con[e, f, i, j] = val + con[e, f, j, i] = val # symmetry + + return con diff --git a/hypyp/sync/envelope_corr.py b/hypyp/sync/envelope_corr.py index 658313a..33f994a 100644 --- a/hypyp/sync/envelope_corr.py +++ b/hypyp/sync/envelope_corr.py @@ -7,55 +7,72 @@ import numpy as np -from .base import 
BaseMetric +from .base import BaseMetric, TORCH_AVAILABLE, NUMBA_AVAILABLE + +if TORCH_AVAILABLE: + import torch + +if NUMBA_AVAILABLE: + from numba import njit, prange class EnvCorr(BaseMetric): """ Envelope Correlation connectivity metric. - + Envelope Correlation measures the correlation between the amplitude envelopes of two signals across time. - + Mathematical formulation: EnvCorr = correlation(|X|, |Y|) over time samples - + The implementation normalizes the amplitudes by subtracting the mean and dividing by the product of standard deviations. - + References ---------- Hipp, J. F., Hawellek, D. J., Corbetta, M., Siegel, M., & Engel, A. K. (2012). Large-scale cortical correlation structure of spontaneous oscillatory activity. Nature Neuroscience, 15(6), 884-890. """ - + name = "envcorr" - + def compute(self, complex_signal: np.ndarray, n_samp: int, transpose_axes: tuple) -> np.ndarray: """ Compute Envelope Correlation. - + Parameters ---------- complex_signal : np.ndarray Complex analytic signals with shape (n_epochs, n_freq, 2*n_channels, n_times). - + n_samp : int Number of time samples. - + transpose_axes : tuple Axes to transpose for matrix multiplication. - + Returns ------- con : np.ndarray Envelope Correlation connectivity matrix with shape (n_epoch, n_freq, 2*n_ch, 2*n_ch). 
""" + if self._backend == 'cuda_kernel': + return self._compute_cuda(complex_signal, n_samp, transpose_axes) + elif self._backend == 'torch': + return self._compute_torch(complex_signal, n_samp, transpose_axes) + elif self._backend == 'numba': + return self._compute_numba(complex_signal, n_samp, transpose_axes) return self._compute_numpy(complex_signal, n_samp, transpose_axes) - + + def _compute_cuda(self, complex_signal, n_samp, transpose_axes): + """CUDA kernel for Envelope Correlation.""" + from .kernels.cuda_amplitude import envcorr_cuda + return envcorr_cuda(complex_signal) + def _compute_numpy(self, complex_signal: np.ndarray, n_samp: int, transpose_axes: tuple) -> np.ndarray: """NumPy implementation of Envelope Correlation.""" @@ -66,3 +83,90 @@ def _compute_numpy(self, complex_signal: np.ndarray, n_samp: int, con = np.einsum('nilm,nimk->nilk', env, env.transpose(transpose_axes)) / \ np.sqrt(np.einsum('nil,nik->nilk', np.sum(env ** 2, axis=3), np.sum(env ** 2, axis=3))) return con + + def _compute_numba(self, complex_signal: np.ndarray, n_samp: int, + transpose_axes: tuple) -> np.ndarray: + """ + Numba JIT implementation of Envelope Correlation. + + Fuses mean-centering, Pearson numerator, and denominator into a + single loop pass with parallel epoch processing. Zero intermediate + tensor allocations — accumulates in CPU registers. + """ + env = np.abs(complex_signal) + return _envcorr_numba_kernel(env) + + def _compute_torch(self, complex_signal: np.ndarray, n_samp: int, + transpose_axes: tuple) -> np.ndarray: + """ + PyTorch implementation of Envelope Correlation. + + Extracts the real-valued amplitude envelope immediately, then + computes Pearson correlation entirely in float32 (MPS) or + float64 (CPU/CUDA). No complex arithmetic on GPU. 
+ """ + device = self._device + float_type = torch.float32 if device == 'mps' else torch.float64 + complex_type = torch.complex64 if device == 'mps' else torch.complex128 + + sig = torch.from_numpy(complex_signal).to(device=device, dtype=complex_type) + env = torch.abs(sig).to(dtype=float_type) # real envelope + del sig # free complex tensor + + # Center the envelope + mu = torch.mean(env, dim=3, keepdim=True) + env = env - mu + + # Pearson numerator: sum_t(env_i(t) * env_j(t)) + num = torch.einsum('efit,efjt->efij', env, env) + + # Denominator: sqrt(sum_t(env_i²) * sum_t(env_j²)) + sum_sq = torch.sum(env ** 2, dim=3) + den = torch.sqrt(torch.einsum('efi,efj->efij', sum_sq, sum_sq)) + + con = num / den + return con.cpu().numpy() + + +# Numba JIT kernel (module-level for caching) +if NUMBA_AVAILABLE: + @njit(parallel=True, cache=True) + def _envcorr_numba_kernel(env): + """ + Fused Pearson correlation on amplitude envelopes. + + For each (epoch, freq): + 1. Pre-compute per-channel mean and sum-of-squared-deviations + 2. 
Pearson correlation for upper triangle, copy by symmetry + """ + n_ep, n_freq, n_ch, n_t = env.shape + con = np.zeros((n_ep, n_freq, n_ch, n_ch)) + + for e in prange(n_ep): + for f in range(n_freq): + # Pre-compute mean and sum_sq per channel + mu = np.zeros(n_ch) + ss = np.zeros(n_ch) + for ch in range(n_ch): + s = 0.0 + for t in range(n_t): + s += env[e, f, ch, t] + mu[ch] = s / n_t + sq = 0.0 + for t in range(n_t): + d = env[e, f, ch, t] - mu[ch] + sq += d * d + ss[ch] = sq + + # Pearson correlation for upper triangle + for i in range(n_ch): + for j in range(i, n_ch): + num = 0.0 + for t in range(n_t): + num += (env[e, f, i, t] - mu[i]) * (env[e, f, j, t] - mu[j]) + denom = np.sqrt(ss[i] * ss[j]) + val = num / denom if denom > 0 else 0.0 + con[e, f, i, j] = val + con[e, f, j, i] = val # symmetry + + return con diff --git a/hypyp/sync/imaginary_coh.py b/hypyp/sync/imaginary_coh.py index ee24315..e2b8af7 100644 --- a/hypyp/sync/imaginary_coh.py +++ b/hypyp/sync/imaginary_coh.py @@ -7,7 +7,14 @@ import numpy as np -from .base import BaseMetric, multiply_conjugate +from .base import BaseMetric, multiply_conjugate, TORCH_AVAILABLE, NUMBA_AVAILABLE + +if TORCH_AVAILABLE: + import torch + from .base import multiply_conjugate_torch + +if NUMBA_AVAILABLE: + from numba import njit, prange class ImCoh(BaseMetric): @@ -51,8 +58,19 @@ def compute(self, complex_signal: np.ndarray, n_samp: int, Imaginary Coherence connectivity matrix with shape (n_epoch, n_freq, 2*n_ch, 2*n_ch). 
""" + if self._backend == 'cuda_kernel': + return self._compute_cuda(complex_signal, n_samp, transpose_axes) + elif self._backend == 'torch': + return self._compute_torch(complex_signal, n_samp, transpose_axes) + elif self._backend == 'numba': + return self._compute_numba(complex_signal, n_samp, transpose_axes) return self._compute_numpy(complex_signal, n_samp, transpose_axes) - + + def _compute_cuda(self, complex_signal, n_samp, transpose_axes): + """CUDA kernel for Imaginary Coherence.""" + from .kernels.cuda_amplitude import imcoh_cuda + return imcoh_cuda(complex_signal) + def _compute_numpy(self, complex_signal: np.ndarray, n_samp: int, transpose_axes: tuple) -> np.ndarray: """NumPy implementation of Imaginary Coherence.""" @@ -63,3 +81,89 @@ def _compute_numpy(self, complex_signal: np.ndarray, n_samp: int, con = np.abs(np.imag(dphi)) / np.sqrt(np.einsum('nil,nik->nilk', np.nansum(amp, axis=3), np.nansum(amp, axis=3))) return con + + def _compute_numba(self, complex_signal: np.ndarray, n_samp: int, + transpose_axes: tuple) -> np.ndarray: + """ + Numba JIT implementation of Imaginary Coherence. + + Same fused kernel as Coh but returns |Im(cross-spectrum)| instead + of |cross-spectrum|. This keeps only the non-zero-lag component, + rejecting volume conduction artifacts. + """ + c = np.real(complex_signal) + s = np.imag(complex_signal) + return _imcoh_numba_kernel(c, s) + + def _compute_torch(self, complex_signal: np.ndarray, n_samp: int, + transpose_axes: tuple) -> np.ndarray: + """ + PyTorch implementation of Imaginary Coherence. + + Computes Im(X_i * conj(X_j)) = s_i*c_j - c_i*s_j directly with + 2 real-valued einsum, instead of building the full complex cross-spectrum + via multiply_conjugate_torch (4 einsum + complex tensor). This halves + GPU memory usage and avoids MPS corruption on large signals. + + MPS uses float32; CPU/CUDA uses float64. 
+ """ + device = self._device + float_type = torch.float32 if device == 'mps' else torch.float64 + + sig = torch.from_numpy(complex_signal).to(device=device, + dtype=torch.complex64 if device == 'mps' + else torch.complex128) + c, s = sig.real, sig.imag + + # Im(X_i * conj(X_j)) = s_i*c_j - c_i*s_j — 2 einsum, no complex tensor + formula = 'efit,efjt->efij' + im_cross = torch.einsum(formula, s, c) - torch.einsum(formula, c, s) + + # Power normalization: sqrt(sum|X_i|² * sum|X_j|²) + amp = c ** 2 + s ** 2 # |X|² without creating complex abs + power = torch.sum(amp, dim=3) + den = torch.sqrt(torch.einsum('efi,efj->efij', power, power)) + + con = torch.abs(im_cross) / den + return con.cpu().numpy() + + +# Numba JIT kernel (module-level for caching) +if NUMBA_AVAILABLE: + @njit(parallel=True, cache=True) + def _imcoh_numba_kernel(c, s): + """ + Fused ImCoh: cross-spectrum imaginary part + power normalization. + + Same as Coh kernel but returns |im_sum| / sqrt(pow_i * pow_j) + instead of sqrt(re_sum² + im_sum²) / sqrt(pow_i * pow_j). 
+ """ + n_ep, n_freq, n_ch, n_t = c.shape + con = np.zeros((n_ep, n_freq, n_ch, n_ch)) + + for e in prange(n_ep): + for f in range(n_freq): + # Pre-compute power per channel + power = np.zeros(n_ch) + for ch in range(n_ch): + p = 0.0 + for t in range(n_t): + p += c[e, f, ch, t] ** 2 + s[e, f, ch, t] ** 2 + power[ch] = p + + # Cross-spectrum imaginary part for all pairs + for i in range(n_ch): + for j in range(i, n_ch): + im_sum = 0.0 + for t in range(n_t): + # Im(X_i * conj(X_j)) = s_i*c_j - c_i*s_j + im_sum += s[e, f, i, t] * c[e, f, j, t] - c[e, f, i, t] * s[e, f, j, t] + denom = np.sqrt(power[i] * power[j]) + if denom > 0: + val = np.abs(im_sum) / denom + else: + val = 0.0 + con[e, f, i, j] = val + con[e, f, j, i] = val # symmetry + + return con diff --git a/hypyp/sync/kernels/__init__.py b/hypyp/sync/kernels/__init__.py new file mode 100644 index 0000000..dc7003c --- /dev/null +++ b/hypyp/sync/kernels/__init__.py @@ -0,0 +1,23 @@ +""" +Custom GPU kernels for sync metrics. + +Provides Metal (Apple Silicon) and CUDA (NVIDIA) implementations +for metrics that cannot be efficiently expressed with torch operations +(e.g., PLI, wPLI — non-linear per-timepoint operations). +""" + +# Metal availability (Apple Silicon via PyObjC) +try: + import Metal as _Metal + METAL_AVAILABLE = True +except ImportError: + METAL_AVAILABLE = False + +# CUDA availability (NVIDIA via CuPy) +try: + import cupy as _cp + CUPY_AVAILABLE = True +except ImportError: + CUPY_AVAILABLE = False + +__all__ = ["METAL_AVAILABLE", "CUPY_AVAILABLE"] diff --git a/hypyp/sync/kernels/_cuda_dispatch.py b/hypyp/sync/kernels/_cuda_dispatch.py new file mode 100644 index 0000000..b2fc391 --- /dev/null +++ b/hypyp/sync/kernels/_cuda_dispatch.py @@ -0,0 +1,68 @@ +""" +Shared CUDA dispatch logic for all pairwise sync metric kernels. + +Uses CuPy RawKernel for inline CUDA source. All kernels use float64 +for exact precision (A100 has 9.7 TFLOPS fp64). +""" + +import numpy as np + +from . 
import CUPY_AVAILABLE + +if CUPY_AVAILABLE: + import cupy as cp + + +def run_pairwise_kernel(complex_signal, get_kernel_fn): + """ + Shared dispatch for pairwise CUDA kernels. + + Parameters + ---------- + complex_signal : np.ndarray, shape (E, F, C, T) + get_kernel_fn : callable -> CuPy RawKernel + + Returns + ------- + np.ndarray, shape (E, F, C, C), float64 + """ + kernel = get_kernel_fn() + + E, F, C, T = complex_signal.shape + n_ef = E * F + + c_flat = cp.asarray( + np.ascontiguousarray(np.real(complex_signal).reshape(n_ef, C, T)), + dtype=cp.float64) + s_flat = cp.asarray( + np.ascontiguousarray(np.imag(complex_signal).reshape(n_ef, C, T)), + dtype=cp.float64) + + # Upper-triangle pair indices + idx_i, idx_j = [], [] + for i in range(C): + for j in range(i, C): + idx_i.append(i) + idx_j.append(j) + pairs_i = cp.asarray(np.array(idx_i, dtype=np.int32)) + pairs_j = cp.asarray(np.array(idx_j, dtype=np.int32)) + n_pairs = len(idx_i) + + out = cp.zeros((n_ef, C, C), dtype=cp.float64) + + total_threads = n_ef * n_pairs + block_size = 256 + grid_size = (total_threads + block_size - 1) // block_size + + kernel( + (grid_size,), (block_size,), + (s_flat, c_flat, out, pairs_i, pairs_j, + n_ef, C, T, n_pairs) + ) + + result = cp.asnumpy(out).reshape(E, F, C, C) + + # Explicit cleanup: force immediate GPU memory release (not relying on GC) + cp.get_default_memory_pool().free_all_blocks() + + return result diff --git a/hypyp/sync/kernels/_metal_dispatch.py b/hypyp/sync/kernels/_metal_dispatch.py new file mode 100644 index 0000000..9f0b1bb --- /dev/null +++ b/hypyp/sync/kernels/_metal_dispatch.py @@ -0,0 +1,120 @@ +""" +Shared Metal dispatch logic for all pairwise sync metric kernels. 
+ +All kernels share the same buffer layout: +- buffer(0): s (imaginary parts), float32 +- buffer(1): c (real parts), float32 +- buffer(2): output, float32 +- buffer(3): pair indices i, uint32 +- buffer(4): pair indices j, uint32 +- buffer(5-8): constants (n_ef, n_ch, n_t, n_pairs) + +ACCorr uses an extended layout with buffer(2) = angle and buffer(3) = output, +so it has its own dispatch function. +""" + +import struct + +import numpy as np + +from . import METAL_AVAILABLE + +if METAL_AVAILABLE: + import Metal + + +def make_const_buffer(device, value): + """Create a Metal buffer containing a single uint32 constant.""" + return device.newBufferWithBytes_length_options_( + struct.pack('I', value), 4, Metal.MTLResourceStorageModeShared) + + +def run_pairwise_kernel(complex_signal, compile_fn): + """ + Shared dispatch for pairwise Metal kernels with standard buffer layout. + + Extracts real/imag as float32, builds upper-triangle pair indices, + dispatches the kernel, and reads back the result. 
+ + Parameters + ---------- + complex_signal : np.ndarray, shape (E, F, C, T) + compile_fn : callable -> (device, pipeline) + + Returns + ------- + np.ndarray, shape (E, F, C, C), float32 + """ + device, pipeline = compile_fn() + + E, F, C, T = complex_signal.shape + n_ef = E * F + + c_flat = np.ascontiguousarray(np.real(complex_signal).reshape(n_ef, C, T), + dtype=np.float32) + s_flat = np.ascontiguousarray(np.imag(complex_signal).reshape(n_ef, C, T), + dtype=np.float32) + + # Upper-triangle pair indices + idx_i, idx_j = [], [] + for i in range(C): + for j in range(i, C): + idx_i.append(i) + idx_j.append(j) + idx_i = np.array(idx_i, dtype=np.uint32) + idx_j = np.array(idx_j, dtype=np.uint32) + n_pairs = len(idx_i) + + # Metal buffers + buf_s = device.newBufferWithBytes_length_options_( + s_flat.tobytes(), s_flat.nbytes, Metal.MTLResourceStorageModeShared) + buf_c = device.newBufferWithBytes_length_options_( + c_flat.tobytes(), c_flat.nbytes, Metal.MTLResourceStorageModeShared) + out_nbytes = n_ef * C * C * 4 + buf_out = device.newBufferWithLength_options_( + out_nbytes, Metal.MTLResourceStorageModeShared) + buf_pi = device.newBufferWithBytes_length_options_( + idx_i.tobytes(), idx_i.nbytes, Metal.MTLResourceStorageModeShared) + buf_pj = device.newBufferWithBytes_length_options_( + idx_j.tobytes(), idx_j.nbytes, Metal.MTLResourceStorageModeShared) + + # Dispatch + try: + queue = device.newCommandQueue() + cmd_buffer = queue.commandBuffer() + encoder = cmd_buffer.computeCommandEncoder() + + encoder.setComputePipelineState_(pipeline) + encoder.setBuffer_offset_atIndex_(buf_s, 0, 0) + encoder.setBuffer_offset_atIndex_(buf_c, 0, 1) + encoder.setBuffer_offset_atIndex_(buf_out, 0, 2) + encoder.setBuffer_offset_atIndex_(buf_pi, 0, 3) + encoder.setBuffer_offset_atIndex_(buf_pj, 0, 4) + encoder.setBuffer_offset_atIndex_(make_const_buffer(device, n_ef), 0, 5) + encoder.setBuffer_offset_atIndex_(make_const_buffer(device, C), 0, 6) + 
encoder.setBuffer_offset_atIndex_(make_const_buffer(device, T), 0, 7) + encoder.setBuffer_offset_atIndex_(make_const_buffer(device, n_pairs), 0, 8) + + total_threads = n_ef * n_pairs + threads_per_group = min(256, pipeline.maxTotalThreadsPerThreadgroup()) + + encoder.dispatchThreads_threadsPerThreadgroup_( + Metal.MTLSize(total_threads, 1, 1), + Metal.MTLSize(threads_per_group, 1, 1)) + encoder.endEncoding() + + cmd_buffer.commit() + cmd_buffer.waitUntilCompleted() + + out_ptr = buf_out.contents() + membuf = out_ptr.as_buffer(out_nbytes) + result = np.frombuffer(membuf, dtype=np.float32).copy().reshape(n_ef, C, C) + + return result.reshape(E, F, C, C) + finally: + # Critical: Release all Metal buffers to prevent GPU memory leak + buf_s.release() + buf_c.release() + buf_out.release() + buf_pi.release() + buf_pj.release() diff --git a/hypyp/sync/kernels/cuda_accorr.py b/hypyp/sync/kernels/cuda_accorr.py new file mode 100644 index 0000000..76b60dc --- /dev/null +++ b/hypyp/sync/kernels/cuda_accorr.py @@ -0,0 +1,128 @@ +""" +CUDA kernel for ACCorr (Adjusted Circular Correlation). +Float64 for exact precision on NVIDIA GPUs. + +ACCorr requires a custom dispatch (not run_pairwise_kernel) because +it needs an extra angle buffer for the sin^2 denominator in pass 2. +""" + +import numpy as np + +from . 
import CUPY_AVAILABLE + +if CUPY_AVAILABLE: + import cupy as cp + + +_ACCORR_SOURCE = r""" +extern "C" __global__ void accorr_kernel( + const double* __restrict__ s, + const double* __restrict__ c, + const double* __restrict__ angle, + double* __restrict__ out, + const int* __restrict__ pairs_i, + const int* __restrict__ pairs_j, + int n_ef, int n_ch, int n_t, int n_pairs) +{ + int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid >= n_ef * n_pairs) return; + int ef = gid / n_pairs, p = gid % n_pairs; + int i = pairs_i[p], j = pairs_j[p]; + int base = ef * n_ch * n_t; + + // Pass 1: cross-products over T + double cc=0, ss=0, cs=0, sc=0; + for (int t = 0; t < n_t; t++) { + double ci=c[base+i*n_t+t], si=s[base+i*n_t+t]; + double cj=c[base+j*n_t+t], sj=s[base+j*n_t+t]; + cc += ci*cj; ss += si*sj; + cs += ci*sj; sc += si*cj; + } + + double re_conj = cc + ss; + double im_conj = -(cs - sc); + double re_prod = cc - ss; + double im_prod = cs + sc; + + double r_minus = sqrt(re_conj*re_conj + im_conj*im_conj); + double r_plus = sqrt(re_prod*re_prod + im_prod*im_prod); + double num = r_minus - r_plus; + + double mean_diff = atan2(im_conj, re_conj); + double mean_sum = atan2(im_prod, re_prod); + double n_adj = -0.5 * (mean_diff - mean_sum); + double m_adj = mean_diff + n_adj; + + // Pass 2: sin^2 adjusted phases over T + double sum_x2=0, sum_y2=0; + for (int t = 0; t < n_t; t++) { + double ai = angle[base+i*n_t+t]; + double aj = angle[base+j*n_t+t]; + double sx = sin(ai - m_adj); + double sy = sin(aj - n_adj); + sum_x2 += sx*sx; sum_y2 += sy*sy; + } + + double den = 2.0 * sqrt(sum_x2 * sum_y2); + double v = (den > 0.0) ? (num / den) : 0.0; + int ob = ef*n_ch*n_ch; + out[ob+i*n_ch+j] = v; out[ob+j*n_ch+i] = v; +} +""" + +_accorr_kernel = None +def _get_accorr(): + global _accorr_kernel + if _accorr_kernel is None: + _accorr_kernel = cp.RawKernel(_ACCORR_SOURCE, "accorr_kernel") + return _accorr_kernel + + +def accorr_cuda(complex_signal): + """ + ACCorr via CUDA. 
Two-pass: cross-products + sin^2 denominator. Float64. + + Custom dispatch (not run_pairwise_kernel) because ACCorr needs an + extra angle buffer for the sin^2 denominator in pass 2. + """ + kernel = _get_accorr() + + E, F, C, T = complex_signal.shape + n_ef = E * F + + z = complex_signal / np.abs(complex_signal) + c_flat = cp.asarray( + np.ascontiguousarray(np.real(z).reshape(n_ef, C, T)), dtype=cp.float64) + s_flat = cp.asarray( + np.ascontiguousarray(np.imag(z).reshape(n_ef, C, T)), dtype=cp.float64) + angle_flat = cp.asarray( + np.ascontiguousarray(np.angle(complex_signal).reshape(n_ef, C, T)), + dtype=cp.float64) + + idx_i, idx_j = [], [] + for i in range(C): + for j in range(i, C): + idx_i.append(i) + idx_j.append(j) + pairs_i = cp.asarray(np.array(idx_i, dtype=np.int32)) + pairs_j = cp.asarray(np.array(idx_j, dtype=np.int32)) + n_pairs = len(idx_i) + + out = cp.zeros((n_ef, C, C), dtype=cp.float64) + + total_threads = n_ef * n_pairs + block_size = 256 + grid_size = (total_threads + block_size - 1) // block_size + + kernel( + (grid_size,), (block_size,), + (s_flat, c_flat, angle_flat, out, pairs_i, pairs_j, + n_ef, C, T, n_pairs) + ) + + result = cp.asnumpy(out).reshape(E, F, C, C) + + # Explicit cleanup: force immediate GPU memory release (not relying on GC) + cp.get_default_memory_pool().free_all_blocks() + + return result diff --git a/hypyp/sync/kernels/cuda_amplitude.py b/hypyp/sync/kernels/cuda_amplitude.py new file mode 100644 index 0000000..07b2acf --- /dev/null +++ b/hypyp/sync/kernels/cuda_amplitude.py @@ -0,0 +1,206 @@ +""" +CUDA kernels for amplitude-based sync metrics: Coh, ImCoh, EnvCorr, PowCorr. +All float64 for exact precision on NVIDIA GPUs. +""" + +import numpy as np + +from . 
import CUPY_AVAILABLE +from ._cuda_dispatch import run_pairwise_kernel + +if CUPY_AVAILABLE: + import cupy as cp + + +# ========================================================================= +# Coh +# ========================================================================= + +_COH_SOURCE = r""" +extern "C" __global__ void coh_kernel( + const double* __restrict__ s, const double* __restrict__ c, + double* __restrict__ out, + const int* __restrict__ pairs_i, const int* __restrict__ pairs_j, + int n_ef, int n_ch, int n_t, int n_pairs) +{ + int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid >= n_ef * n_pairs) return; + int ef = gid / n_pairs, p = gid % n_pairs; + int i = pairs_i[p], j = pairs_j[p]; + int base = ef * n_ch * n_t; + double re=0, im=0, pi=0, pj=0; + for (int t = 0; t < n_t; t++) { + double ci=c[base+i*n_t+t], si=s[base+i*n_t+t]; + double cj=c[base+j*n_t+t], sj=s[base+j*n_t+t]; + re += ci*cj + si*sj; + im += si*cj - ci*sj; + pi += ci*ci + si*si; + pj += cj*cj + sj*sj; + } + double cross = sqrt(re*re + im*im); + double den = sqrt(pi * pj); + double v = (den > 0.0) ? (cross / den) : 0.0; + int ob = ef*n_ch*n_ch; + out[ob+i*n_ch+j] = v; out[ob+j*n_ch+i] = v; +} +""" + +_coh_kernel = None +def _get_coh(): + global _coh_kernel + if _coh_kernel is None: + _coh_kernel = cp.RawKernel(_COH_SOURCE, "coh_kernel") + return _coh_kernel + +def coh_cuda(complex_signal): + """Coh via CUDA. 
Float64.""" + return run_pairwise_kernel(complex_signal, _get_coh) + + +# ========================================================================= +# ImCoh +# ========================================================================= + +_IMCOH_SOURCE = r""" +extern "C" __global__ void imcoh_kernel( + const double* __restrict__ s, const double* __restrict__ c, + double* __restrict__ out, + const int* __restrict__ pairs_i, const int* __restrict__ pairs_j, + int n_ef, int n_ch, int n_t, int n_pairs) +{ + int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid >= n_ef * n_pairs) return; + int ef = gid / n_pairs, p = gid % n_pairs; + int i = pairs_i[p], j = pairs_j[p]; + int base = ef * n_ch * n_t; + double im=0, pi=0, pj=0; + for (int t = 0; t < n_t; t++) { + double ci=c[base+i*n_t+t], si=s[base+i*n_t+t]; + double cj=c[base+j*n_t+t], sj=s[base+j*n_t+t]; + im += si*cj - ci*sj; + pi += ci*ci + si*si; + pj += cj*cj + sj*sj; + } + double den = sqrt(pi * pj); + double v = (den > 0.0) ? (fabs(im) / den) : 0.0; + int ob = ef*n_ch*n_ch; + out[ob+i*n_ch+j] = v; out[ob+j*n_ch+i] = v; +} +""" + +_imcoh_kernel = None +def _get_imcoh(): + global _imcoh_kernel + if _imcoh_kernel is None: + _imcoh_kernel = cp.RawKernel(_IMCOH_SOURCE, "imcoh_kernel") + return _imcoh_kernel + +def imcoh_cuda(complex_signal): + """ImCoh via CUDA. 
Float64.""" + return run_pairwise_kernel(complex_signal, _get_imcoh) + + +# ========================================================================= +# EnvCorr +# ========================================================================= + +_ENVCORR_SOURCE = r""" +extern "C" __global__ void envcorr_kernel( + const double* __restrict__ s, const double* __restrict__ c, + double* __restrict__ out, + const int* __restrict__ pairs_i, const int* __restrict__ pairs_j, + int n_ef, int n_ch, int n_t, int n_pairs) +{ + int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid >= n_ef * n_pairs) return; + int ef = gid / n_pairs, p = gid % n_pairs; + int i = pairs_i[p], j = pairs_j[p]; + int base = ef * n_ch * n_t; + // Pass 1: mean envelope + double si_sum=0, sj_sum=0; + for (int t = 0; t < n_t; t++) { + double ci=c[base+i*n_t+t], si=s[base+i*n_t+t]; + double cj=c[base+j*n_t+t], sj=s[base+j*n_t+t]; + si_sum += sqrt(ci*ci + si*si); + sj_sum += sqrt(cj*cj + sj*sj); + } + double mu_i = si_sum / n_t, mu_j = sj_sum / n_t; + // Pass 2: Pearson + double num=0, di2=0, dj2=0; + for (int t = 0; t < n_t; t++) { + double ci=c[base+i*n_t+t], si=s[base+i*n_t+t]; + double cj=c[base+j*n_t+t], sj=s[base+j*n_t+t]; + double di = sqrt(ci*ci + si*si) - mu_i; + double dj = sqrt(cj*cj + sj*sj) - mu_j; + num += di*dj; di2 += di*di; dj2 += dj*dj; + } + double den = sqrt(di2 * dj2); + double v = (den > 0.0) ? (num / den) : 0.0; + int ob = ef*n_ch*n_ch; + out[ob+i*n_ch+j] = v; out[ob+j*n_ch+i] = v; +} +""" + +_envcorr_kernel = None +def _get_envcorr(): + global _envcorr_kernel + if _envcorr_kernel is None: + _envcorr_kernel = cp.RawKernel(_ENVCORR_SOURCE, "envcorr_kernel") + return _envcorr_kernel + +def envcorr_cuda(complex_signal): + """EnvCorr via CUDA. Pearson on envelopes. 
Float64.""" + return run_pairwise_kernel(complex_signal, _get_envcorr) + + +# ========================================================================= +# PowCorr +# ========================================================================= + +_POWCORR_SOURCE = r""" +extern "C" __global__ void powcorr_kernel( + const double* __restrict__ s, const double* __restrict__ c, + double* __restrict__ out, + const int* __restrict__ pairs_i, const int* __restrict__ pairs_j, + int n_ef, int n_ch, int n_t, int n_pairs) +{ + int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid >= n_ef * n_pairs) return; + int ef = gid / n_pairs, p = gid % n_pairs; + int i = pairs_i[p], j = pairs_j[p]; + int base = ef * n_ch * n_t; + // Pass 1: mean power + double si_sum=0, sj_sum=0; + for (int t = 0; t < n_t; t++) { + double ci=c[base+i*n_t+t], si=s[base+i*n_t+t]; + double cj=c[base+j*n_t+t], sj=s[base+j*n_t+t]; + si_sum += ci*ci + si*si; + sj_sum += cj*cj + sj*sj; + } + double mu_i = si_sum / n_t, mu_j = sj_sum / n_t; + // Pass 2: Pearson + double num=0, di2=0, dj2=0; + for (int t = 0; t < n_t; t++) { + double ci=c[base+i*n_t+t], si=s[base+i*n_t+t]; + double cj=c[base+j*n_t+t], sj=s[base+j*n_t+t]; + double di = (ci*ci + si*si) - mu_i; + double dj = (cj*cj + sj*sj) - mu_j; + num += di*dj; di2 += di*di; dj2 += dj*dj; + } + double den = sqrt(di2 * dj2); + double v = (den > 0.0) ? (num / den) : 0.0; + int ob = ef*n_ch*n_ch; + out[ob+i*n_ch+j] = v; out[ob+j*n_ch+i] = v; +} +""" + +_powcorr_kernel = None +def _get_powcorr(): + global _powcorr_kernel + if _powcorr_kernel is None: + _powcorr_kernel = cp.RawKernel(_POWCORR_SOURCE, "powcorr_kernel") + return _powcorr_kernel + +def powcorr_cuda(complex_signal): + """PowCorr via CUDA. Pearson on power. 
Float64.""" + return run_pairwise_kernel(complex_signal, _get_powcorr) diff --git a/hypyp/sync/kernels/cuda_phase.py b/hypyp/sync/kernels/cuda_phase.py new file mode 100644 index 0000000..e809cee --- /dev/null +++ b/hypyp/sync/kernels/cuda_phase.py @@ -0,0 +1,185 @@ +""" +CUDA kernels for phase-based sync metrics: PLI, wPLI, PLV, CCorr. +All float64 for exact precision on NVIDIA GPUs. +""" + +import numpy as np + +from . import CUPY_AVAILABLE +from ._cuda_dispatch import run_pairwise_kernel + +if CUPY_AVAILABLE: + import cupy as cp + + +# ========================================================================= +# PLI +# ========================================================================= + +_PLI_SOURCE = r""" +extern "C" __global__ void pli_kernel( + const double* __restrict__ s, const double* __restrict__ c, + double* __restrict__ out, + const int* __restrict__ pairs_i, const int* __restrict__ pairs_j, + int n_ef, int n_ch, int n_t, int n_pairs) +{ + int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid >= n_ef * n_pairs) return; + int ef = gid / n_pairs, p = gid % n_pairs; + int i = pairs_i[p], j = pairs_j[p]; + if (i == j) { out[ef*n_ch*n_ch + i*n_ch+j] = 0.0; return; } + int base = ef * n_ch * n_t; + double sign_sum = 0.0; + for (int t = 0; t < n_t; t++) { + double im = s[base+i*n_t+t]*c[base+j*n_t+t] - c[base+i*n_t+t]*s[base+j*n_t+t]; + if (im > 0.0) sign_sum += 1.0; + else if (im < 0.0) sign_sum -= 1.0; + } + double v = fabs(sign_sum) / (double)n_t; + int ob = ef*n_ch*n_ch; + out[ob+i*n_ch+j] = v; out[ob+j*n_ch+i] = v; +} +""" + +_pli_kernel = None +def _get_pli(): + global _pli_kernel + if _pli_kernel is None: + _pli_kernel = cp.RawKernel(_PLI_SOURCE, "pli_kernel") + return _pli_kernel + +def pli_cuda(complex_signal): + """PLI via CUDA. 
Float64.""" + return run_pairwise_kernel(complex_signal, _get_pli) + + +# ========================================================================= +# wPLI +# ========================================================================= + +_WPLI_SOURCE = r""" +extern "C" __global__ void wpli_kernel( + const double* __restrict__ s, const double* __restrict__ c, + double* __restrict__ out, + const int* __restrict__ pairs_i, const int* __restrict__ pairs_j, + int n_ef, int n_ch, int n_t, int n_pairs) +{ + int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid >= n_ef * n_pairs) return; + int ef = gid / n_pairs, p = gid % n_pairs; + int i = pairs_i[p], j = pairs_j[p]; + if (i == j) { out[ef*n_ch*n_ch + i*n_ch+j] = 0.0; return; } + int base = ef * n_ch * n_t; + double im_sum = 0.0, abs_sum = 0.0; + for (int t = 0; t < n_t; t++) { + double im = s[base+i*n_t+t]*c[base+j*n_t+t] - c[base+i*n_t+t]*s[base+j*n_t+t]; + im_sum += im; abs_sum += fabs(im); + } + double v = (abs_sum > 0.0) ? (fabs(im_sum) / abs_sum) : 0.0; + int ob = ef*n_ch*n_ch; + out[ob+i*n_ch+j] = v; out[ob+j*n_ch+i] = v; +} +""" + +_wpli_kernel = None +def _get_wpli(): + global _wpli_kernel + if _wpli_kernel is None: + _wpli_kernel = cp.RawKernel(_WPLI_SOURCE, "wpli_kernel") + return _wpli_kernel + +def wpli_cuda(complex_signal): + """wPLI via CUDA. 
Float64.""" + return run_pairwise_kernel(complex_signal, _get_wpli) + + +# ========================================================================= +# PLV +# ========================================================================= + +_PLV_SOURCE = r""" +extern "C" __global__ void plv_kernel( + const double* __restrict__ s, const double* __restrict__ c, + double* __restrict__ out, + const int* __restrict__ pairs_i, const int* __restrict__ pairs_j, + int n_ef, int n_ch, int n_t, int n_pairs) +{ + int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid >= n_ef * n_pairs) return; + int ef = gid / n_pairs, p = gid % n_pairs; + int i = pairs_i[p], j = pairs_j[p]; + int base = ef * n_ch * n_t; + double re = 0.0, im = 0.0; + for (int t = 0; t < n_t; t++) { + double ci=c[base+i*n_t+t], si=s[base+i*n_t+t]; + double cj=c[base+j*n_t+t], sj=s[base+j*n_t+t]; + re += ci*cj + si*sj; + im += si*cj - ci*sj; + } + double v = sqrt(re*re + im*im) / (double)n_t; + int ob = ef*n_ch*n_ch; + out[ob+i*n_ch+j] = v; out[ob+j*n_ch+i] = v; +} +""" + +_plv_kernel = None +def _get_plv(): + global _plv_kernel + if _plv_kernel is None: + _plv_kernel = cp.RawKernel(_PLV_SOURCE, "plv_kernel") + return _plv_kernel + +def plv_cuda(complex_signal): + """PLV via CUDA. Phase-normalizes then cross-spectrum. 
Float64.""" + z = complex_signal / np.abs(complex_signal) + return run_pairwise_kernel(z, _get_plv) + + +# ========================================================================= +# CCorr +# ========================================================================= + +_CCORR_SOURCE = r""" +extern "C" __global__ void ccorr_kernel( + const double* __restrict__ s, const double* __restrict__ c, + double* __restrict__ out, + const int* __restrict__ pairs_i, const int* __restrict__ pairs_j, + int n_ef, int n_ch, int n_t, int n_pairs) +{ + int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (gid >= n_ef * n_pairs) return; + int ef = gid / n_pairs, p = gid % n_pairs; + int i = pairs_i[p], j = pairs_j[p]; + int base = ef * n_ch * n_t; + // Pass 1: C_bar, S_bar + double cbi=0, sbi=0, cbj=0, sbj=0; + for (int t = 0; t < n_t; t++) { + cbi += c[base+i*n_t+t]; sbi += s[base+i*n_t+t]; + cbj += c[base+j*n_t+t]; sbj += s[base+j*n_t+t]; + } + cbi /= n_t; sbi /= n_t; cbj /= n_t; sbj /= n_t; + // Pass 2: Pearson + double num=0, di2=0, dj2=0; + for (int t = 0; t < n_t; t++) { + double di = s[base+i*n_t+t]*cbi - c[base+i*n_t+t]*sbi; + double dj = s[base+j*n_t+t]*cbj - c[base+j*n_t+t]*sbj; + num += di*dj; di2 += di*di; dj2 += dj*dj; + } + double den = sqrt(di2 * dj2); + double v = (den > 0.0) ? (fabs(num) / den) : 0.0; + int ob = ef*n_ch*n_ch; + out[ob+i*n_ch+j] = v; out[ob+j*n_ch+i] = v; +} +""" + +_ccorr_kernel = None +def _get_ccorr(): + global _ccorr_kernel + if _ccorr_kernel is None: + _ccorr_kernel = cp.RawKernel(_CCORR_SOURCE, "ccorr_kernel") + return _ccorr_kernel + +def ccorr_cuda(complex_signal): + """CCorr via CUDA. Phase-normalizes then angle-free Pearson. 
Float64."""
+    z = complex_signal / np.abs(complex_signal)
+    return run_pairwise_kernel(z, _get_ccorr)
diff --git a/hypyp/sync/kernels/metal_accorr.py b/hypyp/sync/kernels/metal_accorr.py
new file mode 100644
index 0000000..a96f70d
--- /dev/null
+++ b/hypyp/sync/kernels/metal_accorr.py
@@ -0,0 +1,194 @@
+"""
+Metal kernel for Adjusted Circular Correlation (ACCorr).
+
+ACCorr is the most complex metric — it requires a two-pass kernel:
+- Pass 1: cross-products (cc, ss, cs, sc) → numerator + phase adjustments
+- Pass 2: sin^2 adjusted phases → denominator
+
+This module uses an extended buffer layout (3 input buffers: s, c, angle)
+instead of the standard 2-buffer layout used by other metrics.
+"""
+
+import struct
+from functools import lru_cache
+
+import numpy as np
+
+from . import METAL_AVAILABLE
+from ._metal_dispatch import make_const_buffer
+
+if METAL_AVAILABLE:
+    import Metal
+
+
+_ACCORR_SHADER = """
+#include <metal_stdlib>
+using namespace metal;
+
+kernel void accorr_kernel(
+    device const float* s [[buffer(0)]],
+    device const float* c [[buffer(1)]],
+    device const float* angle [[buffer(2)]],
+    device float* out [[buffer(3)]],
+    device const uint* pairs_i [[buffer(4)]],
+    device const uint* pairs_j [[buffer(5)]],
+    constant uint& n_ef [[buffer(6)]],
+    constant uint& n_ch [[buffer(7)]],
+    constant uint& n_t [[buffer(8)]],
+    constant uint& n_pairs [[buffer(9)]],
+    uint gid [[thread_position_in_grid]])
+{
+    uint total = n_ef * n_pairs;
+    if (gid >= total) return;
+
+    uint ef_idx = gid / n_pairs;
+    uint pair_idx = gid % n_pairs;
+    uint i = pairs_i[pair_idx];
+    uint j = pairs_j[pair_idx];
+    uint base = ef_idx * n_ch * n_t;
+
+    // Pass 1: cross-products over T
+    float cc_sum = 0.0, ss_sum = 0.0, cs_sum = 0.0, sc_sum = 0.0;
+    for (uint t = 0; t < n_t; t++) {
+        float ci = c[base + i * n_t + t];
+        float si = s[base + i * n_t + t];
+        float cj = c[base + j * n_t + t];
+        float sj = s[base + j * n_t + t];
+        cc_sum += ci * cj;
+        ss_sum += si * sj;
+        cs_sum += ci * sj;
+        sc_sum += 
si * cj; + } + + float re_conj = cc_sum + ss_sum; + float im_conj = -(cs_sum - sc_sum); + float re_prod = cc_sum - ss_sum; + float im_prod = cs_sum + sc_sum; + + float r_minus = sqrt(re_conj * re_conj + im_conj * im_conj); + float r_plus = sqrt(re_prod * re_prod + im_prod * im_prod); + float num = r_minus - r_plus; + + float mean_diff = atan2(im_conj, re_conj); + float mean_sum = atan2(im_prod, re_prod); + float n_adj = -0.5 * (mean_diff - mean_sum); + float m_adj = mean_diff + n_adj; + + // Pass 2: sin^2 adjusted phases over T + float sum_x2 = 0.0, sum_y2 = 0.0; + for (uint t = 0; t < n_t; t++) { + float ai = angle[base + i * n_t + t]; + float aj = angle[base + j * n_t + t]; + float sx = sin(ai - m_adj); + float sy = sin(aj - n_adj); + sum_x2 += sx * sx; + sum_y2 += sy * sy; + } + + float den = 2.0 * sqrt(sum_x2 * sum_y2); + float accorr = (den > 0.0) ? (num / den) : 0.0; + + uint out_base = ef_idx * n_ch * n_ch; + out[out_base + i * n_ch + j] = accorr; + out[out_base + j * n_ch + i] = accorr; +} +""" + + +@lru_cache(maxsize=1) +def _compile_accorr(): + device = Metal.MTLCreateSystemDefaultDevice() + options = Metal.MTLCompileOptions.new() + library, error = device.newLibraryWithSource_options_error_(_ACCORR_SHADER, options, None) + if error: + raise RuntimeError(f"Metal ACCorr shader failed: {error}") + fn = library.newFunctionWithName_("accorr_kernel") + pipeline, error = device.newComputePipelineStateWithFunction_error_(fn, None) + if error: + raise RuntimeError(f"Metal ACCorr pipeline failed: {error}") + return device, pipeline + + +def accorr_metal(complex_signal: np.ndarray) -> np.ndarray: + """ + ACCorr via Metal. Two-pass kernel: cross-products + sin^2 denominator. + + Uses 3 input buffers (s, c, angle) instead of standard 2. 
+ """ + device, pipeline = _compile_accorr() + + E, F, C, T = complex_signal.shape + n_ef = E * F + + z = complex_signal / np.abs(complex_signal) + c_flat = np.ascontiguousarray(np.real(z).reshape(n_ef, C, T), dtype=np.float32) + s_flat = np.ascontiguousarray(np.imag(z).reshape(n_ef, C, T), dtype=np.float32) + angle_flat = np.ascontiguousarray( + np.angle(complex_signal).reshape(n_ef, C, T), dtype=np.float32) + + idx_i, idx_j = [], [] + for i in range(C): + for j in range(i, C): + idx_i.append(i) + idx_j.append(j) + idx_i = np.array(idx_i, dtype=np.uint32) + idx_j = np.array(idx_j, dtype=np.uint32) + n_pairs = len(idx_i) + + # Metal buffers — extended layout for ACCorr + buf_s = device.newBufferWithBytes_length_options_( + s_flat.tobytes(), s_flat.nbytes, Metal.MTLResourceStorageModeShared) + buf_c = device.newBufferWithBytes_length_options_( + c_flat.tobytes(), c_flat.nbytes, Metal.MTLResourceStorageModeShared) + buf_angle = device.newBufferWithBytes_length_options_( + angle_flat.tobytes(), angle_flat.nbytes, Metal.MTLResourceStorageModeShared) + out_nbytes = n_ef * C * C * 4 + buf_out = device.newBufferWithLength_options_( + out_nbytes, Metal.MTLResourceStorageModeShared) + buf_pi = device.newBufferWithBytes_length_options_( + idx_i.tobytes(), idx_i.nbytes, Metal.MTLResourceStorageModeShared) + buf_pj = device.newBufferWithBytes_length_options_( + idx_j.tobytes(), idx_j.nbytes, Metal.MTLResourceStorageModeShared) + + # Dispatch + try: + queue = device.newCommandQueue() + cmd_buffer = queue.commandBuffer() + encoder = cmd_buffer.computeCommandEncoder() + + encoder.setComputePipelineState_(pipeline) + encoder.setBuffer_offset_atIndex_(buf_s, 0, 0) + encoder.setBuffer_offset_atIndex_(buf_c, 0, 1) + encoder.setBuffer_offset_atIndex_(buf_angle, 0, 2) + encoder.setBuffer_offset_atIndex_(buf_out, 0, 3) + encoder.setBuffer_offset_atIndex_(buf_pi, 0, 4) + encoder.setBuffer_offset_atIndex_(buf_pj, 0, 5) + encoder.setBuffer_offset_atIndex_(make_const_buffer(device, n_ef), 
0, 6) + encoder.setBuffer_offset_atIndex_(make_const_buffer(device, C), 0, 7) + encoder.setBuffer_offset_atIndex_(make_const_buffer(device, T), 0, 8) + encoder.setBuffer_offset_atIndex_(make_const_buffer(device, n_pairs), 0, 9) + + total_threads = n_ef * n_pairs + threads_per_group = min(256, pipeline.maxTotalThreadsPerThreadgroup()) + + encoder.dispatchThreads_threadsPerThreadgroup_( + Metal.MTLSize(total_threads, 1, 1), + Metal.MTLSize(threads_per_group, 1, 1)) + encoder.endEncoding() + + cmd_buffer.commit() + cmd_buffer.waitUntilCompleted() + + out_ptr = buf_out.contents() + membuf = out_ptr.as_buffer(out_nbytes) + result = np.frombuffer(membuf, dtype=np.float32).copy().reshape(n_ef, C, C) + + return result.reshape(E, F, C, C) + finally: + # Critical: Release all Metal buffers to prevent GPU memory leak + buf_s.release() + buf_c.release() + buf_angle.release() + buf_out.release() + buf_pi.release() + buf_pj.release() diff --git a/hypyp/sync/kernels/metal_phase.py b/hypyp/sync/kernels/metal_phase.py new file mode 100644 index 0000000..2bc4e4a --- /dev/null +++ b/hypyp/sync/kernels/metal_phase.py @@ -0,0 +1,157 @@ +""" +Metal kernels for sign-based sync metrics: PLI, wPLI. + +These metrics work on the imaginary part of the cross-spectrum and +cannot be efficiently expressed as batched einsum/BLAS operations, +making custom kernels faster than torch on Apple Silicon. +""" + +from functools import lru_cache + +import numpy as np + +from . 
import METAL_AVAILABLE
+from ._metal_dispatch import run_pairwise_kernel
+
+if METAL_AVAILABLE:
+    import Metal
+
+
+# =========================================================================
+# PLI
+# =========================================================================
+
+_PLI_SHADER = """
+#include <metal_stdlib>
+using namespace metal;
+
+kernel void pli_kernel(
+    device const float* s [[buffer(0)]],
+    device const float* c [[buffer(1)]],
+    device float* out [[buffer(2)]],
+    device const uint* pairs_i [[buffer(3)]],
+    device const uint* pairs_j [[buffer(4)]],
+    constant uint& n_ef [[buffer(5)]],
+    constant uint& n_ch [[buffer(6)]],
+    constant uint& n_t [[buffer(7)]],
+    constant uint& n_pairs [[buffer(8)]],
+    uint gid [[thread_position_in_grid]])
+{
+    uint total = n_ef * n_pairs;
+    if (gid >= total) return;
+
+    uint ef_idx = gid / n_pairs;
+    uint pair_idx = gid % n_pairs;
+    uint i = pairs_i[pair_idx];
+    uint j = pairs_j[pair_idx];
+
+    if (i == j) {
+        uint out_base = ef_idx * n_ch * n_ch;
+        out[out_base + i * n_ch + j] = 0.0;
+        return;
+    }
+
+    uint base = ef_idx * n_ch * n_t;
+    float sign_sum = 0.0;
+    for (uint t = 0; t < n_t; t++) {
+        float im = fma(s[base + i * n_t + t], c[base + j * n_t + t],
+                       -(c[base + i * n_t + t] * s[base + j * n_t + t]));
+        if (im > 0.0) sign_sum += 1.0;
+        else if (im < 0.0) sign_sum -= 1.0;
+    }
+
+    float pli = abs(sign_sum) / float(n_t);
+    uint out_base = ef_idx * n_ch * n_ch;
+    out[out_base + i * n_ch + j] = pli;
+    out[out_base + j * n_ch + i] = pli;
+}
+"""
+
+
+@lru_cache(maxsize=1)
+def _compile_pli():
+    device = Metal.MTLCreateSystemDefaultDevice()
+    options = Metal.MTLCompileOptions.new()
+    library, error = device.newLibraryWithSource_options_error_(_PLI_SHADER, options, None)
+    if error:
+        raise RuntimeError(f"Metal PLI shader failed: {error}")
+    fn = library.newFunctionWithName_("pli_kernel")
+    pipeline, error = device.newComputePipelineStateWithFunction_error_(fn, None)
+    if error:
+        raise RuntimeError(f"Metal PLI pipeline failed: 
{error}")
+    return device, pipeline
+
+
+def pli_metal(complex_signal):
+    """PLI via Metal. sign(Im(cross-spectrum)) per timepoint."""
+    return run_pairwise_kernel(complex_signal, _compile_pli)
+
+
+# =========================================================================
+# wPLI
+# =========================================================================
+
+_WPLI_SHADER = """
+#include <metal_stdlib>
+using namespace metal;
+
+kernel void wpli_kernel(
+    device const float* s [[buffer(0)]],
+    device const float* c [[buffer(1)]],
+    device float* out [[buffer(2)]],
+    device const uint* pairs_i [[buffer(3)]],
+    device const uint* pairs_j [[buffer(4)]],
+    constant uint& n_ef [[buffer(5)]],
+    constant uint& n_ch [[buffer(6)]],
+    constant uint& n_t [[buffer(7)]],
+    constant uint& n_pairs [[buffer(8)]],
+    uint gid [[thread_position_in_grid]])
+{
+    uint total = n_ef * n_pairs;
+    if (gid >= total) return;
+
+    uint ef_idx = gid / n_pairs;
+    uint pair_idx = gid % n_pairs;
+    uint i = pairs_i[pair_idx];
+    uint j = pairs_j[pair_idx];
+
+    if (i == j) {
+        uint out_base = ef_idx * n_ch * n_ch;
+        out[out_base + i * n_ch + j] = 0.0;
+        return;
+    }
+
+    uint base = ef_idx * n_ch * n_t;
+    float im_sum = 0.0, abs_sum = 0.0;
+    for (uint t = 0; t < n_t; t++) {
+        float im = fma(s[base + i * n_t + t], c[base + j * n_t + t],
+                       -(c[base + i * n_t + t] * s[base + j * n_t + t]));
+        im_sum += im;
+        abs_sum += fabs(im);
+    }
+
+    float wpli = (abs_sum > 0.0) ? 
(fabs(im_sum) / abs_sum) : 0.0; + uint out_base = ef_idx * n_ch * n_ch; + out[out_base + i * n_ch + j] = wpli; + out[out_base + j * n_ch + i] = wpli; +} +""" + + +@lru_cache(maxsize=1) +def _compile_wpli(): + device = Metal.MTLCreateSystemDefaultDevice() + options = Metal.MTLCompileOptions.new() + library, error = device.newLibraryWithSource_options_error_(_WPLI_SHADER, options, None) + if error: + raise RuntimeError(f"Metal wPLI shader failed: {error}") + fn = library.newFunctionWithName_("wpli_kernel") + pipeline, error = device.newComputePipelineStateWithFunction_error_(fn, None) + if error: + raise RuntimeError(f"Metal wPLI pipeline failed: {error}") + return device, pipeline + + +def wpli_metal(complex_signal): + """wPLI via Metal. |sum(Im)| / sum(|Im|) per timepoint.""" + return run_pairwise_kernel(complex_signal, _compile_wpli) diff --git a/hypyp/sync/pli.py b/hypyp/sync/pli.py index a54a117..078b378 100644 --- a/hypyp/sync/pli.py +++ b/hypyp/sync/pli.py @@ -7,20 +7,26 @@ import numpy as np -from .base import BaseMetric, multiply_conjugate_time +from .base import BaseMetric, multiply_conjugate_time, TORCH_AVAILABLE, NUMBA_AVAILABLE + +if TORCH_AVAILABLE: + import torch + +if NUMBA_AVAILABLE: + from numba import njit, prange class PLI(BaseMetric): """ Phase Lag Index (PLI) connectivity metric. - + PLI measures the asymmetry of the distribution of instantaneous phase differences. It is insensitive to volume conduction as it ignores zero-lag interactions. - + Mathematical formulation: PLI = |⟨sign(Im(XY*))⟩| - + References ---------- Stam, C. J., Nolte, G., & Daffertshofer, A. (2007). Phase lag index: @@ -28,32 +34,40 @@ class PLI(BaseMetric): with diminished bias from common sources. Human Brain Mapping, 28(11), 1178-1193. """ - + name = "pli" - + def compute(self, complex_signal: np.ndarray, n_samp: int, transpose_axes: tuple) -> np.ndarray: """ Compute Phase Lag Index. 
- + Parameters ---------- complex_signal : np.ndarray Complex analytic signals with shape (n_epochs, n_freq, 2*n_channels, n_times). - + n_samp : int Number of time samples. - + transpose_axes : tuple Axes to transpose for matrix multiplication. - + Returns ------- con : np.ndarray PLI connectivity matrix with shape (n_epoch, n_freq, 2*n_ch, 2*n_ch). """ + if self._backend == 'metal': + return self._compute_metal(complex_signal, n_samp, transpose_axes) + elif self._backend == 'cuda_kernel': + return self._compute_cuda(complex_signal, n_samp, transpose_axes) + elif self._backend == 'torch': + return self._compute_torch(complex_signal, n_samp, transpose_axes) + elif self._backend == 'numba': + return self._compute_numba(complex_signal, n_samp, transpose_axes) return self._compute_numpy(complex_signal, n_samp, transpose_axes) - + def _compute_numpy(self, complex_signal: np.ndarray, n_samp: int, transpose_axes: tuple) -> np.ndarray: """NumPy implementation of Phase Lag Index.""" @@ -62,3 +76,110 @@ def _compute_numpy(self, complex_signal: np.ndarray, n_samp: int, dphi = multiply_conjugate_time(c, s, transpose_axes=transpose_axes) con = np.abs(np.mean(np.sign(np.imag(dphi)), axis=4)) return con + + def _compute_metal(self, complex_signal: np.ndarray, n_samp: int, + transpose_axes: tuple) -> np.ndarray: + """ + Metal compute shader implementation of PLI on Apple Silicon GPU. + + Each GPU thread processes one (epoch×freq, channel_pair) combination, + looping over timepoints. No intermediate tensor — O(1) memory per thread. + ~3x faster than numba on 256+ channel data. + + Requires: pip install pyobjc-framework-Metal + """ + from .kernels.metal_phase import pli_metal + return pli_metal(complex_signal) + + def _compute_cuda(self, complex_signal: np.ndarray, n_samp: int, + transpose_axes: tuple) -> np.ndarray: + """ + CUDA kernel implementation of PLI on NVIDIA GPU. 
+ + Requires: pip install cupy-cuda12x + """ + from .kernels.cuda_phase import pli_cuda + return pli_cuda(complex_signal) + + def _compute_numba(self, complex_signal: np.ndarray, n_samp: int, + transpose_axes: tuple) -> np.ndarray: + """ + Numba JIT implementation of PLI with fused kernel. + + Computes Im(X_i * conj(X_j)) = s_i*c_j - c_i*s_j and sign() + directly in the inner loop, eliminating the 5D intermediate tensor. + Memory: O(C²) instead of O(C² × T). Parallelized over epochs. + + Note: PLI uses the raw signal (not phase-normalized). The sign() + operation makes the result invariant to amplitude anyway. + """ + c = np.real(complex_signal) + s = np.imag(complex_signal) + return _pli_numba_kernel(c, s) + + def _compute_torch(self, complex_signal: np.ndarray, n_samp: int, + transpose_axes: tuple) -> np.ndarray: + """ + PyTorch implementation of Phase Lag Index using per-channel broadcast. + + For each channel i, broadcasts s_i and c_i against ALL channels j + simultaneously using element-wise ops. No einsum, no gather copies. + + s[:,:,i:i+1,:] is a contiguous slice (zero-copy view on GPU). + The broadcast produces (E, F, C, T) intermediates — the same size + as the input, not C² × T like the einsum approach. + + MPS uses float32 precision; CPU/CUDA uses float64. 
+ """ + device = self._device + float_type = torch.float32 if device == 'mps' else torch.float64 + complex_type = torch.complex64 if device == 'mps' else torch.complex128 + + sig = torch.from_numpy(complex_signal).to(device=device, dtype=complex_type) + n_epochs, n_freq, n_ch, n_times = sig.shape + c, s = sig.real, sig.imag # (E, F, C, T) + + con = torch.zeros((n_epochs, n_freq, n_ch, n_ch), + device=device, dtype=float_type) + + for i in range(n_ch): + # s[:,:,i:i+1,:] is a VIEW (contiguous slice), no copy + # Broadcasting against (E, F, C, T) produces (E, F, C, T) + im = s[:, :, i:i+1, :] * c - c[:, :, i:i+1, :] * s # (E, F, C, T) + con[:, :, i, :] = torch.abs(torch.mean(torch.sign(im), dim=-1)) + + return con.cpu().numpy() + + +# Numba JIT kernel (module-level for caching) +if NUMBA_AVAILABLE: + @njit(parallel=True, cache=True) + def _pli_numba_kernel(c, s): + """ + Fused PLI: sign(Im(cross-spectrum)) averaged over time. + + Im(X_i * conj(X_j)) = s_i*c_j - c_i*s_j + PLI = |mean_t(sign(Im))| + + No 5D tensor — O(C²) memory instead of O(C² × T). 
+ """ + n_ep, n_freq, n_ch, n_t = c.shape + con = np.zeros((n_ep, n_freq, n_ch, n_ch)) + + for e in prange(n_ep): + for f in range(n_freq): + for i in range(n_ch): + for j in range(i, n_ch): + sign_sum = 0.0 + for t in range(n_t): + im = s[e, f, i, t] * c[e, f, j, t] \ + - c[e, f, i, t] * s[e, f, j, t] + if im > 0: + sign_sum += 1.0 + elif im < 0: + sign_sum -= 1.0 + val = abs(sign_sum) / n_t + con[e, f, i, j] = val + con[e, f, j, i] = val + + return con diff --git a/hypyp/sync/plv.py b/hypyp/sync/plv.py index 62d63ee..ce3207d 100644 --- a/hypyp/sync/plv.py +++ b/hypyp/sync/plv.py @@ -7,7 +7,14 @@ import numpy as np -from .base import BaseMetric, multiply_conjugate +from .base import BaseMetric, multiply_conjugate, TORCH_AVAILABLE, NUMBA_AVAILABLE + +if TORCH_AVAILABLE: + import torch + from .base import multiply_conjugate_torch + +if NUMBA_AVAILABLE: + from numba import njit, prange class PLV(BaseMetric): @@ -49,8 +56,19 @@ def compute(self, complex_signal: np.ndarray, n_samp: int, con : np.ndarray PLV connectivity matrix with shape (n_epoch, n_freq, 2*n_ch, 2*n_ch). 
""" + if self._backend == 'cuda_kernel': + return self._compute_cuda(complex_signal, n_samp, transpose_axes) + elif self._backend == 'torch': + return self._compute_torch(complex_signal, n_samp, transpose_axes) + elif self._backend == 'numba': + return self._compute_numba(complex_signal, n_samp, transpose_axes) return self._compute_numpy(complex_signal, n_samp, transpose_axes) - + + def _compute_cuda(self, complex_signal, n_samp, transpose_axes): + """CUDA kernel for PLV.""" + from .kernels.cuda_phase import plv_cuda + return plv_cuda(complex_signal) + def _compute_numpy(self, complex_signal: np.ndarray, n_samp: int, transpose_axes: tuple) -> np.ndarray: """NumPy implementation of PLV.""" @@ -60,3 +78,78 @@ def _compute_numpy(self, complex_signal: np.ndarray, n_samp: int, dphi = multiply_conjugate(c, s, transpose_axes=transpose_axes) con = abs(dphi) / n_samp return con + + def _compute_numba(self, complex_signal: np.ndarray, n_samp: int, + transpose_axes: tuple) -> np.ndarray: + """ + Numba JIT implementation of PLV with parallel epoch processing. + + Fuses the 4 einsum operations of multiply_conjugate into a single + loop pass, avoiding intermediate tensor allocations. Uses prange + for parallelization across epochs. + + This is significantly faster than numpy for PLV because: + 1. Zero intermediate allocations (numpy creates 4 temporary tensors) + 2. Single-pass accumulation in CPU registers + 3. prange parallelizes across epochs + """ + phase = complex_signal / np.abs(complex_signal) + c = np.real(phase) + s = np.imag(phase) + return _plv_numba_kernel(c, s, n_samp) + + def _compute_torch(self, complex_signal: np.ndarray, n_samp: int, + transpose_axes: tuple) -> np.ndarray: + """ + PyTorch implementation of PLV. + + Uses multiply_conjugate_torch to offload the 4 einsum operations + to GPU. The einsum contracts the time dimension, so the output is + directly (E, F, C, C) — no 5D intermediate tensor. + + MPS uses float32; CPU/CUDA uses float64. 
+ """ + device = self._device + complex_type = torch.complex64 if device == 'mps' else torch.complex128 + + sig = torch.from_numpy(complex_signal).to(device=device, dtype=complex_type) + + # Normalize to unit magnitude (phase only) + phase = sig / torch.abs(sig) + c, s = phase.real, phase.imag + + # Cross-spectrum with time contraction: (E, F, C, C) + dphi = multiply_conjugate_torch(c, s) + + con = torch.abs(dphi) / n_samp + return con.cpu().numpy() + + +# Numba JIT kernel (module-level for caching) +if NUMBA_AVAILABLE: + @njit(parallel=True, cache=True) + def _plv_numba_kernel(c, s, n_samp): + """ + Fused PLV computation: multiply_conjugate + abs in a single pass. + + Computes |sum_t(z_i(t) * conj(z_j(t)))| / T for all (i,j) pairs, + where z = c + i*s (unit-magnitude phase signal). + + Parallelized over epochs with prange. + """ + n_ep, n_freq, n_ch, n_t = c.shape + con = np.zeros((n_ep, n_freq, n_ch, n_ch)) + + for e in prange(n_ep): + for f in range(n_freq): + for i in range(n_ch): + for j in range(n_ch): + re_sum = 0.0 + im_sum = 0.0 + for t in range(n_t): + # z_i * conj(z_j) = (c_i*c_j + s_i*s_j) + i*(s_i*c_j - c_i*s_j) + re_sum += c[e, f, i, t] * c[e, f, j, t] + s[e, f, i, t] * s[e, f, j, t] + im_sum += s[e, f, i, t] * c[e, f, j, t] - c[e, f, i, t] * s[e, f, j, t] + con[e, f, i, j] = np.sqrt(re_sum**2 + im_sum**2) / n_samp + + return con diff --git a/hypyp/sync/pow_corr.py b/hypyp/sync/pow_corr.py index 1cc2e6e..747dfaf 100644 --- a/hypyp/sync/pow_corr.py +++ b/hypyp/sync/pow_corr.py @@ -7,55 +7,72 @@ import numpy as np -from .base import BaseMetric +from .base import BaseMetric, TORCH_AVAILABLE, NUMBA_AVAILABLE + +if TORCH_AVAILABLE: + import torch + +if NUMBA_AVAILABLE: + from numba import njit, prange class PowCorr(BaseMetric): """ Power Correlation connectivity metric. - + Power Correlation measures the correlation between the power (squared amplitude) of two signals across time. 
- + Mathematical formulation: PowCorr = correlation(|X|², |Y|²) over time samples - + The implementation normalizes the power values by subtracting the mean and dividing by the product of standard deviations. - + References ---------- Colclough, G. L., Woolrich, M. W., Tewarie, P. K., Brookes, M. J., Quinn, A. J., & Smith, S. M. (2016). How reliable are MEG resting-state connectivity metrics? NeuroImage, 138, 284-293. """ - + name = "powcorr" - + def compute(self, complex_signal: np.ndarray, n_samp: int, transpose_axes: tuple) -> np.ndarray: """ Compute Power Correlation. - + Parameters ---------- complex_signal : np.ndarray Complex analytic signals with shape (n_epochs, n_freq, 2*n_channels, n_times). - + n_samp : int Number of time samples. - + transpose_axes : tuple Axes to transpose for matrix multiplication. - + Returns ------- con : np.ndarray Power Correlation connectivity matrix with shape (n_epoch, n_freq, 2*n_ch, 2*n_ch). """ + if self._backend == 'cuda_kernel': + return self._compute_cuda(complex_signal, n_samp, transpose_axes) + elif self._backend == 'torch': + return self._compute_torch(complex_signal, n_samp, transpose_axes) + elif self._backend == 'numba': + return self._compute_numba(complex_signal, n_samp, transpose_axes) return self._compute_numpy(complex_signal, n_samp, transpose_axes) - + + def _compute_cuda(self, complex_signal, n_samp, transpose_axes): + """CUDA kernel for Power Correlation.""" + from .kernels.cuda_amplitude import powcorr_cuda + return powcorr_cuda(complex_signal) + def _compute_numpy(self, complex_signal: np.ndarray, n_samp: int, transpose_axes: tuple) -> np.ndarray: """NumPy implementation of Power Correlation.""" @@ -66,3 +83,95 @@ def _compute_numpy(self, complex_signal: np.ndarray, n_samp: int, con = np.einsum('nilm,nimk->nilk', env, env.transpose(transpose_axes)) / \ np.sqrt(np.einsum('nil,nik->nilk', np.sum(env ** 2, axis=3), np.sum(env ** 2, axis=3))) return con + + def _compute_numba(self, complex_signal: np.ndarray, 
n_samp: int, + transpose_axes: tuple) -> np.ndarray: + """ + Numba JIT implementation of Power Correlation. + + Fuses mean-centering, Pearson numerator, and denominator into a + single loop pass with parallel epoch processing. Zero intermediate + tensor allocations — accumulates in CPU registers. + """ + env = np.abs(complex_signal) ** 2 + return _powcorr_numba_kernel(env) + + def _compute_torch(self, complex_signal: np.ndarray, n_samp: int, + transpose_axes: tuple) -> np.ndarray: + """ + PyTorch implementation of Power Correlation. + + Computes power as abs(sig)² then Pearson correlation entirely in + float32 (MPS) or float64 (CPU/CUDA). No complex arithmetic on GPU + after envelope extraction. + + Uses L2-normalization instead of separate numerator/denominator to + avoid float32 underflow: power values are already small (~|z|²), + and the denominator sqrt(sum(dev²_i) * sum(dev²_j)) involves 4th + powers that underflow in float32. + """ + device = self._device + float_type = torch.float32 if device == 'mps' else torch.float64 + complex_type = torch.complex64 if device == 'mps' else torch.complex128 + + sig = torch.from_numpy(complex_signal).to(device=device, dtype=complex_type) + # Power: |z|² via abs then square + env = (torch.abs(sig).to(dtype=float_type)) ** 2 + del sig # free complex tensor + + # Center the power + mu = torch.mean(env, dim=3, keepdim=True) + env = env - mu + + # L2-normalize per channel: env_hat = env / ||env|| + # Then dot(env_hat_i, env_hat_j) = Pearson correlation directly + norm = torch.sqrt(torch.sum(env ** 2, dim=3, keepdim=True)) + norm = torch.where(norm == 0, torch.ones_like(norm), norm) + env = env / norm + + con = torch.einsum('efit,efjt->efij', env, env) + return con.cpu().numpy() + + +# Numba JIT kernel (module-level for caching) +if NUMBA_AVAILABLE: + @njit(parallel=True, cache=True) + def _powcorr_numba_kernel(env): + """ + Fused Pearson correlation on power envelopes. + + For each (epoch, freq): + 1. 
Pre-compute per-channel mean and sum-of-squared-deviations + 2. Pearson correlation for upper triangle, copy by symmetry + """ + n_ep, n_freq, n_ch, n_t = env.shape + con = np.zeros((n_ep, n_freq, n_ch, n_ch)) + + for e in prange(n_ep): + for f in range(n_freq): + # Pre-compute mean and sum_sq per channel + mu = np.zeros(n_ch) + ss = np.zeros(n_ch) + for ch in range(n_ch): + s = 0.0 + for t in range(n_t): + s += env[e, f, ch, t] + mu[ch] = s / n_t + sq = 0.0 + for t in range(n_t): + d = env[e, f, ch, t] - mu[ch] + sq += d * d + ss[ch] = sq + + # Pearson correlation for upper triangle + for i in range(n_ch): + for j in range(i, n_ch): + num = 0.0 + for t in range(n_t): + num += (env[e, f, i, t] - mu[i]) * (env[e, f, j, t] - mu[j]) + denom = np.sqrt(ss[i] * ss[j]) + val = num / denom if denom > 0 else 0.0 + con[e, f, i, j] = val + con[e, f, j, i] = val # symmetry + + return con diff --git a/hypyp/sync/wpli.py b/hypyp/sync/wpli.py index eec7792..a76ced0 100644 --- a/hypyp/sync/wpli.py +++ b/hypyp/sync/wpli.py @@ -7,20 +7,26 @@ import numpy as np -from .base import BaseMetric, multiply_conjugate_time +from .base import BaseMetric, multiply_conjugate_time, TORCH_AVAILABLE, NUMBA_AVAILABLE + +if TORCH_AVAILABLE: + import torch + +if NUMBA_AVAILABLE: + from numba import njit, prange class WPLI(BaseMetric): """ Weighted Phase Lag Index (wPLI) connectivity metric. - + wPLI is a modification of PLI that weights the contribution of each phase difference by its distance from the real axis. This reduces sensitivity to noise-induced perturbations of small phase differences. - + Mathematical formulation: wPLI = |⟨|Im(XY*)| sign(Im(XY*))⟩| / ⟨|Im(XY*)|⟩ - + References ---------- Vinck, M., Oostenveld, R., van Wingerden, M., Battaglia, F., & Pennartz, @@ -28,40 +34,171 @@ class WPLI(BaseMetric): physiological data in the presence of volume-conduction, noise and sample-size bias. NeuroImage, 55(4), 1548-1565. 
""" - + name = "wpli" - + def compute(self, complex_signal: np.ndarray, n_samp: int, transpose_axes: tuple) -> np.ndarray: """ Compute Weighted Phase Lag Index. - + Parameters ---------- complex_signal : np.ndarray Complex analytic signals with shape (n_epochs, n_freq, 2*n_channels, n_times). - + n_samp : int Number of time samples. - + transpose_axes : tuple Axes to transpose for matrix multiplication. - + Returns ------- con : np.ndarray wPLI connectivity matrix with shape (n_epoch, n_freq, 2*n_ch, 2*n_ch). """ + if self._backend == 'metal': + return self._compute_metal(complex_signal, n_samp, transpose_axes) + elif self._backend == 'cuda_kernel': + return self._compute_cuda(complex_signal, n_samp, transpose_axes) + elif self._backend == 'torch': + return self._compute_torch(complex_signal, n_samp, transpose_axes) + elif self._backend == 'numba': + return self._compute_numba(complex_signal, n_samp, transpose_axes) return self._compute_numpy(complex_signal, n_samp, transpose_axes) - + + def _compute_metal(self, complex_signal: np.ndarray, n_samp: int, + transpose_axes: tuple) -> np.ndarray: + """Metal compute shader for wPLI on Apple Silicon GPU.""" + from .kernels.metal_phase import wpli_metal + return wpli_metal(complex_signal) + + def _compute_cuda(self, complex_signal: np.ndarray, n_samp: int, + transpose_axes: tuple) -> np.ndarray: + """CUDA kernel for wPLI on NVIDIA GPU.""" + from .kernels.cuda_phase import wpli_cuda + return wpli_cuda(complex_signal) + + def _compute_numba(self, complex_signal: np.ndarray, n_samp: int, + transpose_axes: tuple) -> np.ndarray: + """ + Numba JIT implementation of wPLI with fused kernel. + + Computes Im(X_i * conj(X_j)) and accumulates |mean(Im)| / mean(|Im|) + directly in the inner loop. No 5D tensor. + Uses the simplification |Im|*sign(Im) = Im for the numerator. + + Note: wPLI uses the raw signal (not phase-normalized), unlike PLI. 
+ """ + c = np.real(complex_signal) + s = np.imag(complex_signal) + return _wpli_numba_kernel(c, s) + def _compute_numpy(self, complex_signal: np.ndarray, n_samp: int, transpose_axes: tuple) -> np.ndarray: """NumPy implementation of Weighted Phase Lag Index.""" c = np.real(complex_signal) s = np.imag(complex_signal) dphi = multiply_conjugate_time(c, s, transpose_axes=transpose_axes) - con_num = np.abs(np.mean(np.abs(np.imag(dphi)) * np.sign(np.imag(dphi)), axis=4)) + # |Im(x)| * sign(Im(x)) = Im(x) for all real x + con_num = np.abs(np.mean(np.imag(dphi), axis=4)) con_den = np.mean(np.abs(np.imag(dphi)), axis=4) con_den = np.where(con_den == 0, 1, con_den) con = con_num / con_den return con + + def _compute_torch(self, complex_signal: np.ndarray, n_samp: int, + transpose_axes: tuple) -> np.ndarray: + """ + PyTorch implementation of Weighted Phase Lag Index. + + Chunks computation by (epoch, freq) to avoid materializing the full + 5D tensor ``(E, F, C, C, T)`` which exceeds MPS INT_MAX at high + channel counts. Each chunk ``(C, C, T)`` stays well under the limit. + + Computes Im(X_i * conj(X_j)) = s_i*c_j - c_i*s_j directly with + 2 real einsum instead of 4 complex einsum. Halves GPU memory per chunk. + Uses simplified numerator: |mean(Im)| instead of |mean(|Im|*sign(Im))|. + + MPS uses float32 precision; CPU/CUDA uses float64. + """ + device = self._device + float_type = torch.float32 if device == 'mps' else torch.float64 + complex_type = torch.complex64 if device == 'mps' else torch.complex128 + + sig = torch.from_numpy(complex_signal).to(device=device, dtype=complex_type) + n_epochs, n_freq, n_ch, n_times = sig.shape + c, s = sig.real, sig.imag + + con = torch.zeros((n_epochs, n_freq, n_ch, n_ch), + device=device, dtype=float_type) + + # Chunk by epoch — each chunk is (F, C, C, T), 5x fewer iterations + # than (epoch, freq) chunking. Falls back to double loop if chunk + # would exceed MPS INT_MAX. 
+ chunk_elements = n_freq * n_ch * n_ch * n_times + if device == 'mps' and chunk_elements > 2_000_000_000: + # Fallback: (epoch, freq) chunking for very large configs + formula = 'it,jt->ijt' + for e in range(n_epochs): + for f in range(n_freq): + c_ef = c[e, f] + s_ef = s[e, f] + im_dphi = torch.einsum(formula, s_ef, c_ef) - \ + torch.einsum(formula, c_ef, s_ef) + con_num = torch.abs(torch.mean(im_dphi, dim=-1)) + con_den = torch.mean(torch.abs(im_dphi), dim=-1) + con_den = torch.where(con_den == 0, torch.ones_like(con_den), con_den) + con[e, f] = con_num / con_den + else: + # Fast path: epoch-only chunking + formula = 'fit,fjt->fijt' + for e in range(n_epochs): + c_e = c[e] # (F, C, T) + s_e = s[e] + im_dphi = torch.einsum(formula, s_e, c_e) - \ + torch.einsum(formula, c_e, s_e) + # |Im| * sign(Im) = Im + con_num = torch.abs(torch.mean(im_dphi, dim=-1)) + con_den = torch.mean(torch.abs(im_dphi), dim=-1) + con_den = torch.where(con_den == 0, torch.ones_like(con_den), con_den) + con[e] = con_num / con_den + + return con.cpu().numpy() + + +# Numba JIT kernel (module-level for caching) +if NUMBA_AVAILABLE: + @njit(parallel=True, cache=True) + def _wpli_numba_kernel(c, s): + """ + Fused wPLI: weighted sign of Im(cross-spectrum). + + wPLI = |mean_t(Im)| / mean_t(|Im|) + Uses the simplification |Im|*sign(Im) = Im. + + No 5D tensor — O(C²) memory instead of O(C² × T). 
+ """ + n_ep, n_freq, n_ch, n_t = c.shape + con = np.zeros((n_ep, n_freq, n_ch, n_ch)) + + for e in prange(n_ep): + for f in range(n_freq): + for i in range(n_ch): + for j in range(i, n_ch): + im_sum = 0.0 + abs_sum = 0.0 + for t in range(n_t): + im = s[e, f, i, t] * c[e, f, j, t] \ + - c[e, f, i, t] * s[e, f, j, t] + im_sum += im + abs_sum += abs(im) + if abs_sum > 0: + val = abs(im_sum) / abs_sum + else: + val = 0.0 + con[e, f, i, j] = val + con[e, f, j, i] = val + + return con diff --git a/pyproject.toml b/pyproject.toml index 56dbf92..a49e604 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,8 @@ dependencies = [ shiny = ["shiny>=1.1.0"] numba = ["numba>=0.60.0"] torch = ["torch>=2.0.0"] +metal = ["pyobjc-framework-Metal>=10.0"] +cupy = ["cupy-cuda12x>=13.0.0"] [project.urls] Homepage = "https://github.com/ppsp-team/HyPyP" diff --git a/tests/test_sync.py b/tests/test_sync.py index 75b36fe..54d86f8 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -11,8 +11,13 @@ import pytest from hypyp.analyses import compute_sync +from hypyp.sync import get_metric from hypyp.sync.accorr import ACCorr -from hypyp.sync.base import BaseMetric, NUMBA_AVAILABLE, TORCH_AVAILABLE, MPS_AVAILABLE +from hypyp.sync.base import ( + BaseMetric, AUTO_PRIORITY, + NUMBA_AVAILABLE, TORCH_AVAILABLE, MPS_AVAILABLE, METAL_AVAILABLE, +) +from hypyp.sync.kernels import CUPY_AVAILABLE from tests.accorr_reference import accorr_reference @@ -115,6 +120,786 @@ def test_compute_sync_torch(self, complex_signal, complex_signal_raw): np.testing.assert_allclose(result, result_reference, rtol=1e-9, atol=1e-10) +class TestPLV: + """Tests for Phase Locking Value with all backends.""" + + TRANSPOSE_AXES = (0, 1, 3, 2) + MPS_TOL = 1e-5 # PLV uses smooth operations (sin, cos, abs) — tight tolerance + + def test_plv_shape(self, complex_signal): + """PLV output shape should match input dimensions.""" + from hypyp.sync.plv import PLV + n_samp = complex_signal.shape[3] + result = 
PLV().compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + n_epochs, n_freq, n_ch, _ = complex_signal.shape + assert result.shape == (n_epochs, n_freq, n_ch, n_ch) + + def test_plv_value_range(self, complex_signal): + """PLV values should be in [0, 1].""" + from hypyp.sync.plv import PLV + n_samp = complex_signal.shape[3] + result = PLV().compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + assert np.all(result >= -1e-10) and np.all(result <= 1 + 1e-10) + assert not np.any(np.isnan(result)) + + @pytest.mark.skipif(not NUMBA_AVAILABLE, reason="Numba not available") + def test_plv_numba_vs_numpy(self, complex_signal): + """Numba PLV should match numpy PLV exactly (both float64).""" + from hypyp.sync.plv import PLV + n_samp = complex_signal.shape[3] + result_np = PLV(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + result_numba = PLV(optimization='numba').compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + np.testing.assert_allclose(result_numba, result_np, rtol=1e-9, atol=1e-10) + + @pytest.mark.skipif(not TORCH_AVAILABLE, reason="Torch not available") + def test_plv_torch_vs_numpy(self, complex_signal): + """Torch PLV should match numpy PLV within MPS tolerance.""" + from hypyp.sync.plv import PLV + n_samp = complex_signal.shape[3] + result_np = PLV(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + metric_torch = PLV(optimization='torch') + result_torch = metric_torch.compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + if metric_torch._device == 'mps': + np.testing.assert_allclose(result_torch, result_np, + rtol=self.MPS_TOL, atol=self.MPS_TOL) + else: + np.testing.assert_allclose(result_torch, result_np, rtol=1e-9, atol=1e-10) + + def test_plv_symmetry(self, complex_signal): + """PLV matrix should be symmetric (PLV(i,j) == PLV(j,i)).""" + from hypyp.sync.plv import PLV + n_samp = complex_signal.shape[3] + result = PLV().compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + for e in range(result.shape[0]): 
+ for f in range(result.shape[1]): + np.testing.assert_allclose( + result[e, f], result[e, f].T, rtol=1e-10, atol=1e-12 + ) + + @pytest.mark.skipif(not METAL_AVAILABLE, reason="Metal not available") + def test_plv_metal_vs_numpy(self, complex_signal): + """Metal PLV should match numpy PLV within float32 tolerance.""" + from hypyp.sync.plv import PLV + n_samp = complex_signal.shape[3] + result_np = PLV(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + result_metal = PLV(optimization='metal').compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + np.testing.assert_allclose(result_metal, result_np, rtol=1e-5, atol=1e-5) + + @pytest.mark.skipif(not CUPY_AVAILABLE, reason="CuPy not available") + def test_plv_cuda_vs_numpy(self, complex_signal): + """CUDA PLV should match numpy PLV exactly (both float64).""" + from hypyp.sync.plv import PLV + n_samp = complex_signal.shape[3] + result_np = PLV(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + result_cuda = PLV(optimization='cuda_kernel').compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + np.testing.assert_allclose(result_cuda, result_np, rtol=1e-9, atol=1e-10) + + +class TestCCorr: + """Tests for circular correlation metric.""" + + TRANSPOSE_AXES = (0, 1, 3, 2) + + def test_ccorr_shape(self, complex_signal): + """CCorr output shape should match input dimensions.""" + from hypyp.sync.ccorr import CCorr + metric = CCorr() + n_samp = complex_signal.shape[3] + result = metric.compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + n_epochs, n_freq, n_ch, _ = complex_signal.shape + assert result.shape == (n_epochs, n_freq, n_ch, n_ch) + + def test_ccorr_value_range(self, complex_signal): + """CCorr values should be non-negative (abs of correlation).""" + from hypyp.sync.ccorr import CCorr + metric = CCorr() + n_samp = complex_signal.shape[3] + result = metric.compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + assert np.all(result >= -1e-10) + assert not 
np.any(np.isnan(result)) + + def test_ccorr_vs_scipy_reference(self, complex_signal): + """New inline circmean should match scipy.stats.circmean exactly.""" + from scipy.stats import circmean + from hypyp.sync.ccorr import CCorr + + # Compute with new implementation + metric = CCorr() + n_samp = complex_signal.shape[3] + result_new = metric.compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + + # Compute reference using scipy circmean + n_epoch, n_freq, n_ch_total = complex_signal.shape[:3] + angle = np.angle(complex_signal) + mu_angle_scipy = circmean(angle, high=np.pi, low=-np.pi, axis=3).reshape( + n_epoch, n_freq, n_ch_total, 1 + ) + angle_centered = np.sin(angle - mu_angle_scipy) + formula = 'nilm,nimk->nilk' + transpose_axes = self.TRANSPOSE_AXES + result_scipy = np.abs( + np.einsum(formula, angle_centered, angle_centered.transpose(transpose_axes)) / + np.sqrt(np.einsum('nil,nik->nilk', + np.sum(angle_centered ** 2, axis=3), + np.sum(angle_centered ** 2, axis=3))) + ) + + np.testing.assert_allclose(result_new, result_scipy, rtol=1e-12, atol=1e-14) + + def test_ccorr_symmetry(self, complex_signal): + """CCorr matrix should be symmetric for each epoch/freq.""" + from hypyp.sync.ccorr import CCorr + metric = CCorr() + n_samp = complex_signal.shape[3] + result = metric.compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + for e in range(result.shape[0]): + for f in range(result.shape[1]): + np.testing.assert_allclose( + result[e, f], result[e, f].T, rtol=1e-10, atol=1e-12 + ) + + + @pytest.mark.skipif(not NUMBA_AVAILABLE, reason="Numba not available") + def test_ccorr_numba_vs_numpy(self, complex_signal): + """Numba CCorr should match numpy CCorr exactly (both float64).""" + from hypyp.sync.ccorr import CCorr + n_samp = complex_signal.shape[3] + result_np = CCorr(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + result_numba = CCorr(optimization='numba').compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + 
np.testing.assert_allclose(result_numba, result_np, rtol=1e-9, atol=1e-10) + + @pytest.mark.skipif(not TORCH_AVAILABLE, reason="Torch not available") + def test_ccorr_torch_vs_numpy(self, complex_signal): + """Torch CCorr should match numpy CCorr within MPS tolerance.""" + from hypyp.sync.ccorr import CCorr + n_samp = complex_signal.shape[3] + result_np = CCorr(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + metric_torch = CCorr(optimization='torch') + result_torch = metric_torch.compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + # Angle-free reformulation eliminates transcendental function chain, + # bringing MPS precision in line with PLV. + if metric_torch._device == 'mps': + np.testing.assert_allclose(result_torch, result_np, rtol=1e-5, atol=1e-5) + else: + np.testing.assert_allclose(result_torch, result_np, rtol=1e-9, atol=1e-10) + + @pytest.mark.skipif(not METAL_AVAILABLE, reason="Metal not available") + def test_ccorr_metal_vs_numpy(self, complex_signal): + """Metal CCorr should match numpy CCorr within float32 tolerance. + + Uses Kahan summation with fastMath=OFF to preserve IEEE-754 compliance. 
+ """ + from hypyp.sync.ccorr import CCorr + n_samp = complex_signal.shape[3] + result_np = CCorr(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + result_metal = CCorr(optimization='metal').compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + np.testing.assert_allclose(result_metal, result_np, rtol=1e-5, atol=1e-5) + + @pytest.mark.skipif(not CUPY_AVAILABLE, reason="CuPy not available") + def test_ccorr_cuda_vs_numpy(self, complex_signal): + """CUDA CCorr should match numpy CCorr exactly (both float64).""" + from hypyp.sync.ccorr import CCorr + n_samp = complex_signal.shape[3] + result_np = CCorr(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + result_cuda = CCorr(optimization='cuda_kernel').compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + np.testing.assert_allclose(result_cuda, result_np, rtol=1e-9, atol=1e-10) + + +class TestCoh: + """Tests for Coherence with all backends.""" + + TRANSPOSE_AXES = (0, 1, 3, 2) + MPS_TOL = 1e-5 + + def test_coh_shape(self, complex_signal): + """Coh output shape should match input dimensions.""" + from hypyp.sync.coh import Coh + n_samp = complex_signal.shape[3] + result = Coh().compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + n_epochs, n_freq, n_ch, _ = complex_signal.shape + assert result.shape == (n_epochs, n_freq, n_ch, n_ch) + + def test_coh_value_range(self, complex_signal): + """Coh values should be in [0, 1].""" + from hypyp.sync.coh import Coh + n_samp = complex_signal.shape[3] + result = Coh().compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + assert np.all(result >= -1e-10) and np.all(result <= 1 + 1e-10) + assert not np.any(np.isnan(result)) + + @pytest.mark.skipif(not NUMBA_AVAILABLE, reason="Numba not available") + def test_coh_numba_vs_numpy(self, complex_signal): + """Numba Coh should match numpy Coh exactly (both float64).""" + from hypyp.sync.coh import Coh + n_samp = complex_signal.shape[3] + result_np = 
Coh(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + result_numba = Coh(optimization='numba').compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + np.testing.assert_allclose(result_numba, result_np, rtol=1e-9, atol=1e-10) + + @pytest.mark.skipif(not TORCH_AVAILABLE, reason="Torch not available") + def test_coh_torch_vs_numpy(self, complex_signal): + """Torch Coh should match numpy Coh within MPS tolerance.""" + from hypyp.sync.coh import Coh + n_samp = complex_signal.shape[3] + result_np = Coh(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + metric_torch = Coh(optimization='torch') + result_torch = metric_torch.compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + if metric_torch._device == 'mps': + np.testing.assert_allclose(result_torch, result_np, + rtol=self.MPS_TOL, atol=self.MPS_TOL) + else: + np.testing.assert_allclose(result_torch, result_np, rtol=1e-9, atol=1e-10) + + def test_coh_symmetry(self, complex_signal): + """Coh matrix should be symmetric.""" + from hypyp.sync.coh import Coh + n_samp = complex_signal.shape[3] + result = Coh().compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + for e in range(result.shape[0]): + for f in range(result.shape[1]): + np.testing.assert_allclose( + result[e, f], result[e, f].T, rtol=1e-10, atol=1e-12 + ) + + @pytest.mark.skipif(not METAL_AVAILABLE, reason="Metal not available") + def test_coh_metal_vs_numpy(self, complex_signal): + """Metal Coh should match numpy Coh within float32 tolerance.""" + from hypyp.sync.coh import Coh + n_samp = complex_signal.shape[3] + result_np = Coh(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + result_metal = Coh(optimization='metal').compute(complex_signal, n_samp, self.TRANSPOSE_AXES) + np.testing.assert_allclose(result_metal, result_np, rtol=1e-5, atol=1e-5) + + @pytest.mark.skipif(not CUPY_AVAILABLE, reason="CuPy not available") + def test_coh_cuda_vs_numpy(self, complex_signal): + """CUDA Coh should 
# NOTE(review): the tail of the Coh CUDA comparison test (its class header lies
# before this chunk) was truncated at the chunk boundary and is not reproduced here.


class TestImCoh:
    """Tests for Imaginary Coherence with all backends."""

    TRANSPOSE_AXES = (0, 1, 3, 2)
    MPS_TOL = 1e-5

    def test_imcoh_shape(self, complex_signal):
        """ImCoh output shape should match input dimensions."""
        from hypyp.sync.imaginary_coh import ImCoh
        n_samp = complex_signal.shape[3]
        result = ImCoh().compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        n_epochs, n_freq, n_ch, _ = complex_signal.shape
        assert result.shape == (n_epochs, n_freq, n_ch, n_ch)

    def test_imcoh_value_range(self, complex_signal):
        """ImCoh values should be in [0, 1]."""
        from hypyp.sync.imaginary_coh import ImCoh
        n_samp = complex_signal.shape[3]
        result = ImCoh().compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        assert np.all(result >= -1e-10) and np.all(result <= 1 + 1e-10)
        assert not np.any(np.isnan(result))

    @pytest.mark.skipif(not NUMBA_AVAILABLE, reason="Numba not available")
    def test_imcoh_numba_vs_numpy(self, complex_signal):
        """Numba ImCoh should match numpy ImCoh exactly (both float64)."""
        from hypyp.sync.imaginary_coh import ImCoh
        n_samp = complex_signal.shape[3]
        result_np = ImCoh(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        result_numba = ImCoh(optimization='numba').compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        np.testing.assert_allclose(result_numba, result_np, rtol=1e-9, atol=1e-10)

    @pytest.mark.skipif(not TORCH_AVAILABLE, reason="Torch not available")
    def test_imcoh_torch_vs_numpy(self, complex_signal):
        """Torch ImCoh should match numpy ImCoh within MPS tolerance."""
        from hypyp.sync.imaginary_coh import ImCoh
        n_samp = complex_signal.shape[3]
        result_np = ImCoh(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        metric_torch = ImCoh(optimization='torch')
        result_torch = metric_torch.compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        if metric_torch._device == 'mps':
            np.testing.assert_allclose(result_torch, result_np,
                                       rtol=self.MPS_TOL, atol=self.MPS_TOL)
        else:
            np.testing.assert_allclose(result_torch, result_np, rtol=1e-9, atol=1e-10)

    def test_imcoh_symmetry(self, complex_signal):
        """ImCoh matrix should be symmetric."""
        from hypyp.sync.imaginary_coh import ImCoh
        n_samp = complex_signal.shape[3]
        result = ImCoh().compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        for e in range(result.shape[0]):
            for f in range(result.shape[1]):
                np.testing.assert_allclose(
                    result[e, f], result[e, f].T, rtol=1e-10, atol=1e-12
                )

    @pytest.mark.skipif(not METAL_AVAILABLE, reason="Metal not available")
    def test_imcoh_metal_vs_numpy(self, complex_signal):
        """Metal ImCoh should match numpy ImCoh within float32 tolerance."""
        from hypyp.sync.imaginary_coh import ImCoh
        n_samp = complex_signal.shape[3]
        result_np = ImCoh(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        result_metal = ImCoh(optimization='metal').compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        np.testing.assert_allclose(result_metal, result_np, rtol=1e-5, atol=1e-5)

    @pytest.mark.skipif(not CUPY_AVAILABLE, reason="CuPy not available")
    def test_imcoh_cuda_vs_numpy(self, complex_signal):
        """CUDA ImCoh should match numpy ImCoh exactly (both float64)."""
        from hypyp.sync.imaginary_coh import ImCoh
        n_samp = complex_signal.shape[3]
        result_np = ImCoh(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        result_cuda = ImCoh(optimization='cuda_kernel').compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        np.testing.assert_allclose(result_cuda, result_np, rtol=1e-9, atol=1e-10)


# NOTE(review): TestEnvCorr was defined twice in this file; the second definition
# silently shadowed the first and was a strict superset of it (it added the Metal
# and CUDA comparison tests). The two copies are merged into this single class.
class TestEnvCorr:
    """Tests for Envelope Correlation with all backends."""

    TRANSPOSE_AXES = (0, 1, 3, 2)
    MPS_TOL = 1e-5

    def test_envcorr_shape(self, complex_signal):
        """EnvCorr output shape should match input dimensions."""
        from hypyp.sync.envelope_corr import EnvCorr
        n_samp = complex_signal.shape[3]
        result = EnvCorr().compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        n_epochs, n_freq, n_ch, _ = complex_signal.shape
        assert result.shape == (n_epochs, n_freq, n_ch, n_ch)

    def test_envcorr_value_range(self, complex_signal):
        """EnvCorr values should be in [-1, 1] (Pearson correlation)."""
        from hypyp.sync.envelope_corr import EnvCorr
        n_samp = complex_signal.shape[3]
        result = EnvCorr().compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        assert np.all(result >= -1 - 1e-10) and np.all(result <= 1 + 1e-10)
        assert not np.any(np.isnan(result))

    def test_envcorr_symmetry(self, complex_signal):
        """EnvCorr matrix should be symmetric."""
        from hypyp.sync.envelope_corr import EnvCorr
        n_samp = complex_signal.shape[3]
        result = EnvCorr().compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        for e in range(result.shape[0]):
            for f in range(result.shape[1]):
                np.testing.assert_allclose(
                    result[e, f], result[e, f].T, rtol=1e-10, atol=1e-12
                )

    @pytest.mark.skipif(not NUMBA_AVAILABLE, reason="Numba not available")
    def test_envcorr_numba_vs_numpy(self, complex_signal):
        """Numba EnvCorr should match numpy EnvCorr exactly (both float64)."""
        from hypyp.sync.envelope_corr import EnvCorr
        n_samp = complex_signal.shape[3]
        result_np = EnvCorr(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        result_numba = EnvCorr(optimization='numba').compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        np.testing.assert_allclose(result_numba, result_np, rtol=1e-9, atol=1e-10)

    @pytest.mark.skipif(not TORCH_AVAILABLE, reason="Torch not available")
    def test_envcorr_torch_vs_numpy(self, complex_signal):
        """Torch EnvCorr should match numpy EnvCorr within MPS tolerance."""
        from hypyp.sync.envelope_corr import EnvCorr
        n_samp = complex_signal.shape[3]
        result_np = EnvCorr(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        metric_torch = EnvCorr(optimization='torch')
        result_torch = metric_torch.compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        if metric_torch._device == 'mps':
            np.testing.assert_allclose(result_torch, result_np,
                                       rtol=self.MPS_TOL, atol=self.MPS_TOL)
        else:
            np.testing.assert_allclose(result_torch, result_np, rtol=1e-9, atol=1e-10)

    @pytest.mark.skipif(not METAL_AVAILABLE, reason="Metal not available")
    def test_envcorr_metal_vs_numpy(self, complex_signal):
        """Metal EnvCorr should match numpy EnvCorr within float32 tolerance."""
        from hypyp.sync.envelope_corr import EnvCorr
        n_samp = complex_signal.shape[3]
        result_np = EnvCorr(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        result_metal = EnvCorr(optimization='metal').compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        np.testing.assert_allclose(result_metal, result_np, rtol=1e-5, atol=1e-5)

    @pytest.mark.skipif(not CUPY_AVAILABLE, reason="CuPy not available")
    def test_envcorr_cuda_vs_numpy(self, complex_signal):
        """CUDA EnvCorr should match numpy EnvCorr exactly (both float64)."""
        from hypyp.sync.envelope_corr import EnvCorr
        n_samp = complex_signal.shape[3]
        result_np = EnvCorr(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        result_cuda = EnvCorr(optimization='cuda_kernel').compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        np.testing.assert_allclose(result_cuda, result_np, rtol=1e-9, atol=1e-10)


class TestPLI:
    """Tests for Phase Lag Index with all backends."""

    TRANSPOSE_AXES = (0, 1, 3, 2)
    # PLI uses sign() which is discontinuous at zero. MPS float32 can round
    # imaginary parts near zero differently than float64, flipping the sign
    # for a tiny fraction of values. A looser tolerance is needed.
    MPS_TOL = 1e-2

    def test_pli_shape(self, complex_signal):
        """PLI output shape should match input dimensions."""
        from hypyp.sync.pli import PLI
        metric = PLI()
        n_samp = complex_signal.shape[3]
        result = metric.compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        n_epochs, n_freq, n_ch, _ = complex_signal.shape
        assert result.shape == (n_epochs, n_freq, n_ch, n_ch)

    def test_pli_value_range(self, complex_signal):
        """PLI values should be in [0, 1]."""
        from hypyp.sync.pli import PLI
        metric = PLI()
        n_samp = complex_signal.shape[3]
        result = metric.compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        assert np.all(result >= -1e-10) and np.all(result <= 1 + 1e-10)
        assert not np.any(np.isnan(result))

    @pytest.mark.skipif(not NUMBA_AVAILABLE, reason="Numba not available")
    def test_pli_numba_vs_numpy(self, complex_signal):
        """Numba PLI should match numpy PLI exactly (both float64)."""
        from hypyp.sync.pli import PLI
        n_samp = complex_signal.shape[3]
        result_np = PLI(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        result_numba = PLI(optimization='numba').compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        np.testing.assert_allclose(result_numba, result_np, rtol=1e-9, atol=1e-10)

    @pytest.mark.skipif(not TORCH_AVAILABLE, reason="Torch not available")
    def test_pli_torch_vs_numpy(self, complex_signal):
        """Torch PLI should match numpy PLI."""
        from hypyp.sync.pli import PLI
        n_samp = complex_signal.shape[3]

        result_np = PLI(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        metric_torch = PLI(optimization='torch')
        result_torch = metric_torch.compute(complex_signal, n_samp, self.TRANSPOSE_AXES)

        if metric_torch._device == 'mps':
            np.testing.assert_allclose(result_torch, result_np,
                                       rtol=self.MPS_TOL, atol=self.MPS_TOL)
        else:
            np.testing.assert_allclose(result_torch, result_np, rtol=1e-9, atol=1e-10)

    @pytest.mark.skipif(not TORCH_AVAILABLE, reason="Torch not available")
    def test_pli_torch_shape(self, complex_signal):
        """Torch PLI output shape should match numpy."""
        from hypyp.sync.pli import PLI
        n_samp = complex_signal.shape[3]
        result = PLI(optimization='torch').compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        n_epochs, n_freq, n_ch, _ = complex_signal.shape
        assert result.shape == (n_epochs, n_freq, n_ch, n_ch)

    @pytest.mark.skipif(not TORCH_AVAILABLE, reason="Torch not available")
    def test_pli_torch_large_channels(self):
        """PLI torch should handle 128ch/subject (256 total) without MPS INT_MAX crash."""
        from hypyp.sync.pli import PLI

        rng = np.random.default_rng(42)
        n_ch_per_subject = 128
        sig = rng.standard_normal((2, 1, 2 * n_ch_per_subject, 256)) + \
            1j * rng.standard_normal((2, 1, 2 * n_ch_per_subject, 256))
        n_samp = sig.shape[3]

        result_np = PLI().compute(sig, n_samp, self.TRANSPOSE_AXES)
        result_torch = PLI(optimization='torch').compute(sig, n_samp, self.TRANSPOSE_AXES)

        assert result_torch.shape == result_np.shape
        assert not np.any(np.isnan(result_torch))

    @pytest.mark.skipif(not METAL_AVAILABLE, reason="Metal not available")
    def test_pli_metal_vs_numpy(self, complex_signal):
        """Metal PLI should match numpy PLI within float32 tolerance."""
        from hypyp.sync.pli import PLI
        n_samp = complex_signal.shape[3]
        result_np = PLI(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        result_metal = PLI(optimization='metal').compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        # Float32 precision — sign() near zero can flip
        np.testing.assert_allclose(result_metal, result_np, rtol=1e-2, atol=1e-2)

    @pytest.mark.skipif(not METAL_AVAILABLE, reason="Metal not available")
    def test_pli_metal_large_channels(self):
        """Metal PLI should handle 128ch/subject (256 total)."""
        from hypyp.sync.pli import PLI
        rng = np.random.default_rng(42)
        sig = rng.standard_normal((2, 1, 256, 256)) + 1j * rng.standard_normal((2, 1, 256, 256))
        n_samp = sig.shape[3]
        result = PLI(optimization='metal').compute(sig, n_samp, self.TRANSPOSE_AXES)
        assert result.shape == (2, 1, 256, 256)
        assert not np.any(np.isnan(result))
        assert np.allclose(np.diagonal(result[0, 0]), 0)  # diagonal = 0

    @pytest.mark.skipif(not CUPY_AVAILABLE, reason="CuPy not available")
    def test_pli_cuda_vs_numpy(self, complex_signal):
        """CUDA PLI should match numpy PLI exactly (both float64)."""
        from hypyp.sync.pli import PLI
        n_samp = complex_signal.shape[3]
        result_np = PLI(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        result_cuda = PLI(optimization='cuda_kernel').compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        # Float64: should match to machine precision
        np.testing.assert_allclose(result_cuda, result_np, rtol=1e-9, atol=1e-10)


class TestWPLI:
    """Tests for Weighted Phase Lag Index with all backends."""

    TRANSPOSE_AXES = (0, 1, 3, 2)
    # Same sign() precision issue as PLI, though wPLI weights mitigate it somewhat
    MPS_TOL = 1e-2

    def test_wpli_shape(self, complex_signal):
        """wPLI output shape should match input dimensions."""
        from hypyp.sync.wpli import WPLI
        metric = WPLI()
        n_samp = complex_signal.shape[3]
        result = metric.compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        n_epochs, n_freq, n_ch, _ = complex_signal.shape
        assert result.shape == (n_epochs, n_freq, n_ch, n_ch)

    def test_wpli_value_range(self, complex_signal):
        """wPLI values should be in [0, 1]."""
        from hypyp.sync.wpli import WPLI
        metric = WPLI()
        n_samp = complex_signal.shape[3]
        result = metric.compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        assert np.all(result >= -1e-10) and np.all(result <= 1 + 1e-10)
        assert not np.any(np.isnan(result))

    # NOTE(review): this test previously carried a spurious
    # `skipif(not TORCH_AVAILABLE)` decorator (copy/paste artifact) that wrongly
    # skipped a pure numba-vs-numpy comparison on machines without torch.
    @pytest.mark.skipif(not NUMBA_AVAILABLE, reason="Numba not available")
    def test_wpli_numba_vs_numpy(self, complex_signal):
        """Numba wPLI should match numpy wPLI exactly (both float64)."""
        from hypyp.sync.wpli import WPLI
        n_samp = complex_signal.shape[3]
        result_np = WPLI(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        result_numba = WPLI(optimization='numba').compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        np.testing.assert_allclose(result_numba, result_np, rtol=1e-9, atol=1e-10)

    @pytest.mark.skipif(not TORCH_AVAILABLE, reason="Torch not available")
    def test_wpli_torch_vs_numpy(self, complex_signal):
        """Torch wPLI should match numpy wPLI."""
        from hypyp.sync.wpli import WPLI
        n_samp = complex_signal.shape[3]

        result_np = WPLI(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        metric_torch = WPLI(optimization='torch')
        result_torch = metric_torch.compute(complex_signal, n_samp, self.TRANSPOSE_AXES)

        if metric_torch._device == 'mps':
            np.testing.assert_allclose(result_torch, result_np,
                                       rtol=self.MPS_TOL, atol=self.MPS_TOL)
        else:
            np.testing.assert_allclose(result_torch, result_np, rtol=1e-9, atol=1e-10)

    @pytest.mark.skipif(not TORCH_AVAILABLE, reason="Torch not available")
    def test_wpli_torch_shape(self, complex_signal):
        """Torch wPLI output shape should match numpy."""
        from hypyp.sync.wpli import WPLI
        n_samp = complex_signal.shape[3]
        result = WPLI(optimization='torch').compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        n_epochs, n_freq, n_ch, _ = complex_signal.shape
        assert result.shape == (n_epochs, n_freq, n_ch, n_ch)

    @pytest.mark.skipif(not TORCH_AVAILABLE, reason="Torch not available")
    def test_wpli_torch_large_channels(self):
        """wPLI torch should handle 128ch/subject (256 total) without MPS INT_MAX crash."""
        from hypyp.sync.wpli import WPLI

        rng = np.random.default_rng(42)
        n_ch_per_subject = 128
        sig = rng.standard_normal((2, 1, 2 * n_ch_per_subject, 256)) + \
            1j * rng.standard_normal((2, 1, 2 * n_ch_per_subject, 256))
        n_samp = sig.shape[3]

        result_np = WPLI().compute(sig, n_samp, self.TRANSPOSE_AXES)
        result_torch = WPLI(optimization='torch').compute(sig, n_samp, self.TRANSPOSE_AXES)

        assert result_torch.shape == result_np.shape
        assert not np.any(np.isnan(result_torch))

    # NOTE(review): the two tests below were previously misplaced inside
    # TestPowCorr; they belong here with the rest of the wPLI suite.
    @pytest.mark.skipif(not METAL_AVAILABLE, reason="Metal not available")
    def test_wpli_metal_vs_numpy(self, complex_signal):
        """Metal wPLI should match numpy wPLI within float32 tolerance."""
        from hypyp.sync.wpli import WPLI
        n_samp = complex_signal.shape[3]
        result_np = WPLI(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        result_metal = WPLI(optimization='metal').compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        np.testing.assert_allclose(result_metal, result_np, rtol=1e-2, atol=1e-2)

    @pytest.mark.skipif(not CUPY_AVAILABLE, reason="CuPy not available")
    def test_wpli_cuda_vs_numpy(self, complex_signal):
        """CUDA wPLI should match numpy wPLI exactly (both float64)."""
        from hypyp.sync.wpli import WPLI
        n_samp = complex_signal.shape[3]
        result_np = WPLI(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        result_cuda = WPLI(optimization='cuda_kernel').compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        np.testing.assert_allclose(result_cuda, result_np, rtol=1e-9, atol=1e-10)


class TestPowCorr:
    """Tests for Power Correlation with all backends."""

    TRANSPOSE_AXES = (0, 1, 3, 2)
    MPS_TOL = 1e-5

    def test_powcorr_shape(self, complex_signal):
        """PowCorr output shape should match input dimensions."""
        from hypyp.sync.pow_corr import PowCorr
        n_samp = complex_signal.shape[3]
        result = PowCorr().compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        n_epochs, n_freq, n_ch, _ = complex_signal.shape
        assert result.shape == (n_epochs, n_freq, n_ch, n_ch)

    def test_powcorr_value_range(self, complex_signal):
        """PowCorr values should be in [-1, 1] (Pearson correlation)."""
        from hypyp.sync.pow_corr import PowCorr
        n_samp = complex_signal.shape[3]
        result = PowCorr().compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        assert np.all(result >= -1 - 1e-10) and np.all(result <= 1 + 1e-10)
        assert not np.any(np.isnan(result))

    def test_powcorr_symmetry(self, complex_signal):
        """PowCorr matrix should be symmetric."""
        from hypyp.sync.pow_corr import PowCorr
        n_samp = complex_signal.shape[3]
        result = PowCorr().compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        for e in range(result.shape[0]):
            for f in range(result.shape[1]):
                np.testing.assert_allclose(
                    result[e, f], result[e, f].T, rtol=1e-10, atol=1e-12
                )

    @pytest.mark.skipif(not NUMBA_AVAILABLE, reason="Numba not available")
    def test_powcorr_numba_vs_numpy(self, complex_signal):
        """Numba PowCorr should match numpy PowCorr exactly (both float64)."""
        from hypyp.sync.pow_corr import PowCorr
        n_samp = complex_signal.shape[3]
        result_np = PowCorr(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        result_numba = PowCorr(optimization='numba').compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        np.testing.assert_allclose(result_numba, result_np, rtol=1e-9, atol=1e-10)

    @pytest.mark.skipif(not TORCH_AVAILABLE, reason="Torch not available")
    def test_powcorr_torch_vs_numpy(self, complex_signal):
        """Torch PowCorr should match numpy PowCorr within MPS tolerance."""
        from hypyp.sync.pow_corr import PowCorr
        n_samp = complex_signal.shape[3]
        result_np = PowCorr(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        metric_torch = PowCorr(optimization='torch')
        result_torch = metric_torch.compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        if metric_torch._device == 'mps':
            np.testing.assert_allclose(result_torch, result_np,
                                       rtol=self.MPS_TOL, atol=self.MPS_TOL)
        else:
            np.testing.assert_allclose(result_torch, result_np, rtol=1e-9, atol=1e-10)

    @pytest.mark.skipif(not METAL_AVAILABLE, reason="Metal not available")
    def test_powcorr_metal_vs_numpy(self, complex_signal):
        """Metal PowCorr should match numpy PowCorr within float32 tolerance."""
        from hypyp.sync.pow_corr import PowCorr
        n_samp = complex_signal.shape[3]
        result_np = PowCorr(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        result_metal = PowCorr(optimization='metal').compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        np.testing.assert_allclose(result_metal, result_np, rtol=1e-5, atol=1e-5)

    @pytest.mark.skipif(not CUPY_AVAILABLE, reason="CuPy not available")
    def test_powcorr_cuda_vs_numpy(self, complex_signal):
        """CUDA PowCorr should match numpy PowCorr exactly (both float64)."""
        from hypyp.sync.pow_corr import PowCorr
        n_samp = complex_signal.shape[3]
        result_np = PowCorr(optimization=None).compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        result_cuda = PowCorr(optimization='cuda_kernel').compute(complex_signal, n_samp, self.TRANSPOSE_AXES)
        np.testing.assert_allclose(result_cuda, result_np, rtol=1e-9, atol=1e-10)


class TestAccorrKernels:
    """Tests for ACCorr Metal and CUDA kernels."""

    TRANSPOSE_AXES = (0, 1, 3, 2)

    @pytest.mark.skipif(not METAL_AVAILABLE, reason="Metal not available")
    def test_accorr_metal_vs_numpy(self, complex_signal):
        """Metal ACCorr should match numpy within float32 tolerance."""
        from hypyp.sync.accorr import ACCorr
        n_samp = complex_signal.shape[3]
        result_np = ACCorr(optimization=None, show_progress=False).compute(
            complex_signal, n_samp, self.TRANSPOSE_AXES)
        result_metal = ACCorr(optimization='metal', show_progress=False).compute(
            complex_signal, n_samp, self.TRANSPOSE_AXES)
        np.testing.assert_allclose(result_metal, result_np, rtol=1e-5, atol=1e-5)

    @pytest.mark.skipif(not CUPY_AVAILABLE, reason="CuPy not available")
    def test_accorr_cuda_vs_numpy(self, complex_signal):
        """CUDA ACCorr should match numpy exactly (both float64)."""
        from hypyp.sync.accorr import ACCorr
        n_samp = complex_signal.shape[3]
        result_np = ACCorr(optimization=None, show_progress=False).compute(
            complex_signal, n_samp, self.TRANSPOSE_AXES)
        result_cuda = ACCorr(optimization='cuda_kernel', show_progress=False).compute(
            complex_signal, n_samp, self.TRANSPOSE_AXES)
        np.testing.assert_allclose(result_cuda, result_np, rtol=1e-9, atol=1e-10)


class TestAccorrErrorHandling:
    """Error handling and fallback behavior."""

    # NOTE(review): the other tests of this class (e.g. test_torch_fallback_warning)
    # were elided at this chunk boundary and must be preserved from the original file.

    def test_auto_resolves(self):
        """optimization='auto' should resolve without error."""
        from hypyp.sync.accorr import ACCorr
        metric = ACCorr(optimization='auto')
        assert metric._backend in ('numpy', 'numba', 'torch', 'metal', 'cuda_kernel')


class TestAutoDispatch:
    """Benchmark-driven 'auto' dispatch per metric and platform."""

    def test_auto_all_metrics_resolve(self):
        """optimization='auto' resolves for every metric without error."""
        for metric_name in AUTO_PRIORITY:
            m = get_metric(metric_name, optimization='auto')
            assert m._backend in ('numpy', 'numba', 'torch', 'metal', 'cuda_kernel')

    @pytest.mark.skipif(not MPS_AVAILABLE, reason="MPS not available")
    def test_auto_einsum_prefers_torch_on_mps(self):
        """Einsum metrics should prefer torch on Apple Silicon."""
        for metric_name in ['plv', 'ccorr', 'coh', 'imcoh', 'envcorr', 'powcorr']:
            m = get_metric(metric_name, optimization='auto')
            assert m._backend == 'torch' and m._device == 'mps', (
                f"{metric_name} auto: expected torch/mps, got {m._backend}/{m._device}"
            )

    @pytest.mark.skipif(not MPS_AVAILABLE, reason="MPS not available")
    @pytest.mark.skipif(not METAL_AVAILABLE, reason="Metal not available")
    def test_auto_sign_prefers_metal_on_mps(self):
        """PLI/wPLI/ACCorr should prefer Metal on Apple Silicon."""
        for metric_name in ['pli', 'wpli', 'accorr']:
            m = get_metric(metric_name, optimization='auto')
            assert m._backend == 'metal', (
                f"{metric_name} auto: expected metal, got {m._backend}"
            )

    def test_auto_priority_override(self):
        """Custom priority overrides the AUTO_PRIORITY table."""
        m = get_metric('plv', optimization='auto', priority=['numba'])
        if NUMBA_AVAILABLE:
            assert m._backend == 'numba'
        else:
            assert m._backend == 'numpy'

    def test_auto_priority_skips_unavailable(self):
        """Priority list gracefully skips unavailable backends."""
        with patch('hypyp.sync.base.METAL_AVAILABLE', False), \
             patch('hypyp.sync.base.CUPY_AVAILABLE', False):
            m = get_metric('pli', optimization='auto', priority=['metal', 'cuda_kernel', 'numba'])
            if NUMBA_AVAILABLE:
                assert m._backend == 'numba'
            else:
                assert m._backend == 'numpy'

    def test_auto_fallback_cpu_only(self):
        """On CPU-only machines, auto warns and falls back to numba or numpy."""
        with patch('hypyp.sync.base.MPS_AVAILABLE', False), \
             patch('hypyp.sync.base.CUDA_AVAILABLE', False):
            with pytest.warns(UserWarning, match="No GPU available"):
                m = get_metric('plv', optimization='auto')
            if NUMBA_AVAILABLE:
                assert m._backend == 'numba'
            else:
                assert m._backend == 'numpy'

    def test_priority_parameter_propagated_via_get_metric(self):
        """get_metric passes priority through to the metric class."""
        m = get_metric('accorr', optimization='auto', priority=['numba'])
        assert m._priority == ['numba']