diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py
index 084e52c..8f3666c 100644
--- a/src/ntops/kernels/__init__.py
+++ b/src/ntops/kernels/__init__.py
@@ -36,6 +36,11 @@
     softmax,
     sub,
     tanh,
+    max_pool1d,
+    max_pool3d,
+    stack,
+    mean,
+    median,
 )
 
 __all__ = [
@@ -76,4 +81,9 @@
     "softmax",
     "sub",
     "tanh",
+    "max_pool1d",
+    "max_pool3d",
+    "stack",
+    "mean",
+    "median",
 ]
diff --git a/src/ntops/kernels/max_pool1d.py b/src/ntops/kernels/max_pool1d.py
new file mode 100644
index 0000000..0d1d881
--- /dev/null
+++ b/src/ntops/kernels/max_pool1d.py
@@ -0,0 +1,73 @@
+import functools
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+
+def arrangement(input, output, kernel_size, stride, block_size, ceil_mode):
+    if block_size is None:
+        block_size = ninetoothed.block_size()
+
+    # input: (N, C, L_in)
+    # output: (N, C, L_out)
+
+    # 1. Window tiling.
+    # dim_sizes: (1, 1, kernel_size) -> windows of length kernel_size along L
+    # strides: (1, 1, stride) -> step of `stride` along L
+    # floor_mode=not ceil_mode: decides whether the trailing partial window is dropped
+    input_arranged = input.tile(
+        (1, 1, kernel_size),
+        (1, 1, stride),
+        floor_mode=not ceil_mode,
+    )
+    # => (N, C, L_out), dtype=(1, 1, k)
+
+    # 2. Ravel and flatten.
+    input_arranged = input_arranged.ravel()
+    # => (N, C, L_out, 1, 1, k)
+    input_arranged = input_arranged.flatten(end_dim=3).flatten(start_dim=1)
+    # => (N*C*L_out, k)
+
+    # 3. Pad up to the nearest power of two (for the parallel reduction).
+    # The padding value comes from other=float("-inf") in premake.
+    nearest_pow2 = 1 << (kernel_size - 1).bit_length()
+    input_arranged = input_arranged.tile((1, nearest_pow2))
+    # => (..., 1), dtype=(1, nearest_pow2)
+
+    input_arranged.dtype = input_arranged.dtype.squeeze(0)
+    input_arranged = input_arranged.tile((block_size, -1))
+    input_arranged.dtype = input_arranged.dtype.ravel().squeeze(1)
+    # => (..., 1), dtype=(block_size, nearest_pow2)
+
+    # 4. Arrange the output to match.
+    output_arranged = output.tile((1, 1, 1))  # (N, C, L_out)
+    output_arranged = output_arranged.ravel()
+    output_arranged = output_arranged.flatten(end_dim=3).flatten(start_dim=1)
+    output_arranged = output_arranged.tile((block_size, -1))
+    output_arranged.dtype = output_arranged.dtype.squeeze(1)
+
+    return input_arranged, output_arranged
+
+
+def application(input, output):
+    # input: (block_size, nearest_pow2)
+    # output: (block_size,)
+
+    # Take the max directly; the -inf padding can never win, so it does not affect the result.
+    output = ntl.max(input, axis=1)
+
+
+def premake(ndim, kernel_size, stride, block_size=None, ceil_mode=False, dtype=None):
+    arrangement_ = functools.partial(
+        arrangement,
+        kernel_size=kernel_size,
+        stride=stride,
+        block_size=block_size,
+        ceil_mode=ceil_mode,
+    )
+
+    tensors = (
+        # Input: padded with -inf, the identity element of max pooling.
+        Tensor(ndim, dtype=dtype, other=float("-inf")),
+        Tensor(ndim, dtype=dtype),  # output
+    )
+
+    return arrangement_, application, tensors
\ No newline at end of file
diff --git a/src/ntops/kernels/max_pool3d.py b/src/ntops/kernels/max_pool3d.py
new file mode 100644
index 0000000..26aab33
--- /dev/null
+++ b/src/ntops/kernels/max_pool3d.py
@@ -0,0 +1,71 @@
+import functools
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+
+def arrangement(input, output, k_d, k_h, k_w, s_d, s_h, s_w, block_size, ceil_mode):
+    if block_size is None:
+        block_size = ninetoothed.block_size()
+
+    # input: (N, C, D_in, H_in, W_in)
+    # output: (N, C, D_out, H_out, W_out)
+
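+    # Illustrative example (hypothetical shapes, not from the original code): with
+    # kernel (2, 2, 2) and stride (2, 2, 2), a (1, 1, 4, 4, 4) input tiles into a
+    # (1, 1, 2, 2, 2) grid of windows, each window of dtype (1, 1, 2, 2, 2).
+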
+    # 1. Window tiling (with the extra depth dimension).
+    input_arranged = input.tile(
+        (1, 1, k_d, k_h, k_w),
+        (1, 1, s_d, s_h, s_w),
+        floor_mode=not ceil_mode,
+    )
+    # => (N, C, D_out, H_out, W_out), dtype=(1, 1, k_d, k_h, k_w)
+
+    # 2. Ravel and flatten.
+    input_arranged = input_arranged.ravel()
+    # => (N, C, D_out, H_out, W_out, 1, 1, k_d, k_h, k_w)
+
+    # Note: end_dim=5 here, because the five leading dimensions N, C, D, H, W
+    # are merged and treated as the batch.
+    input_arranged = input_arranged.flatten(end_dim=5).flatten(start_dim=1)
+    # => (N*C*D_out*H_out*W_out, k_d*k_h*k_w)
+
+    # 3. Pad up to the nearest power of two (for the parallel reduction).
+    # The padding value comes from other=float("-inf") in premake.
+    nearest_pow2 = 1 << (k_d * k_h * k_w - 1).bit_length()
+    input_arranged = input_arranged.tile((1, nearest_pow2))
+    # => (..., 1), dtype=(1, nearest_pow2)
+
+    input_arranged.dtype = input_arranged.dtype.squeeze(0)
+    input_arranged = input_arranged.tile((block_size, -1))
+    input_arranged.dtype = input_arranged.dtype.ravel().squeeze(1)
+    # => (..., 1), dtype=(block_size, nearest_pow2)
+
+    # 4. Arrange the output to match.
+    output_arranged = output.tile((1, 1, 1, 1, 1))  # (N, C, D, H, W)
+    output_arranged = output_arranged.ravel()
+    output_arranged = output_arranged.flatten(end_dim=5).flatten(start_dim=1)
+    output_arranged = output_arranged.tile((block_size, -1))
+    output_arranged.dtype = output_arranged.dtype.squeeze(1)
+
+    return input_arranged, output_arranged
+
+
+def application(input, output):
+    # input: (block_size, nearest_pow2)
+    # output: (block_size,)
+
+    # Standard max pooling reduction.
+    output = ntl.max(input, axis=1)
+
+
+def premake(ndim, k_d, k_h, k_w, s_d, s_h, s_w, block_size=None, ceil_mode=False, dtype=None):
+    arrangement_ = functools.partial(
+        arrangement,
+        k_d=k_d, k_h=k_h, k_w=k_w,
+        s_d=s_d, s_h=s_h, s_w=s_w,
+        block_size=block_size,
+        ceil_mode=ceil_mode,
+    )
+
+    tensors = (
+        # Input: padded with -inf, the identity element of max pooling.
+        Tensor(ndim, dtype=dtype, other=float("-inf")),
+        Tensor(ndim, dtype=dtype),  # output
+    )
+
+    return arrangement_, application, tensors
\ No newline at end of file
diff --git a/src/ntops/kernels/mean.py b/src/ntops/kernels/mean.py
new file mode 100644
index 0000000..9a0bcf7
--- /dev/null
+++ b/src/ntops/kernels/mean.py
@@ -0,0 +1,45 @@
+import functools
+
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+# The arrangement logic is generic, so the existing reduction module is reused.
+from ntops.kernels.reduction import arrangement
+
+
+def application(input, output):
+    # The first step of a mean is a sum.
+    # Accumulation happens inside the kernel; the division is left to the caller
+    # for numerical stability.
+    accumulator = 0.0
+    for i in range(input.shape[0]):
+        block_sum = ntl.sum(input[i], axis=0)
+        accumulator += block_sum
+    # Cast the accumulated result to the output type (usually a float type).
+    output[0] = ntl.cast(accumulator, output.dtype.dtype)
+
+
+def premake(ndim, dim, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)
+
+    # The mean operator requires a floating-point output.
+    # If an integer dtype is passed in, it has to be handled according to the DSL;
+    # here the dtype is assumed to already be a float type (converted on the Torch side).
+    tensors = (
+        Tensor(ndim, dtype=dtype),  # input
+        Tensor(ndim, dtype=dtype),  # output
+    )
+    return arrangement_, application, tensors
+
+
+# --- Global reduction (over all elements) ---
+
+def arrangement_all_elements(input, output, block_size=None):
+    input = input.flatten().tile((block_size,))
+    output = output.tile((1,))
+    return input, output
+
+
+def application_all_elements(input, output):
+    output[0] = ntl.sum(input, 0)
+
+
+def premake_all_elements(ndim, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement_all_elements, block_size=block_size)
+    tensors = (
+        Tensor(ndim, dtype=dtype),
+        Tensor(1, dtype=dtype),
+    )
+    return arrangement_, application_all_elements, tensors
\ No newline at end of file
diff --git a/src/ntops/kernels/median.py b/src/ntops/kernels/median.py
new file mode 100644
index 0000000..871b42f
--- /dev/null
+++ b/src/ntops/kernels/median.py
@@ -0,0 +1,56 @@
+import functools
+
+import triton.language as tl  # [FIX 1] import the standard triton language
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+from ntops.kernels.reduction import arrangement
+
+
+def application(input, values, indices, loop_k):
+    val_block = input[0]
+
+    # Initialize the working values for the search: negate the block so that
+    # repeatedly taking the max yields the smallest original elements.
+    working_val = -val_block
+    sentinel = float("-inf")
+
+    # Original indices.
+    idx_block = ntl.arange(0, val_block.shape[0])
+
+    # Result buffers (scalars).
+    final_val = ntl.zeros([], dtype=val_block.dtype)
+    # [FIX 2] use tl.int32 instead of the string "int32"
+    final_idx = ntl.zeros([], dtype=tl.int32)
+
+    # Selection loop.
+    for i in range(loop_k + 1):
+        current_max_val = ntl.max(working_val, axis=0)
+        current_max_idx = ntl.argmax(working_val, axis=0)  # returns int32
+
+        if i == loop_k:
+            real_val = -current_max_val
+            final_val = ntl.cast(real_val, values.dtype.dtype)
+            final_idx = current_max_idx  # int32 -> int32
+
+        # Mask out the element that was just selected.
+        mask_selected = idx_block == current_max_idx
+        updated_working_val = ntl.where(mask_selected, sentinel, working_val)
+        working_val = ntl.cast(updated_working_val, working_val.dtype)
+
+    # Write the outputs back.
+    values[0] = final_val
+    # Cast int32 -> int64 (or whatever type the output requires).
+    indices[0] = ntl.cast(final_idx, indices.dtype.dtype)
+
+
+def premake(
+    ndim, dim, loop_k, dtype=None, indices_dtype=None, block_size=None
+):
+    arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)
+    # Pad with +inf so that padded slots are never selected as the minimum.
+    pad_val = float("inf")
+
+    tensors = (
+        Tensor(ndim, dtype=dtype, other=pad_val),  # input
+        Tensor(ndim, dtype=dtype),  # output values
+        Tensor(ndim, dtype=indices_dtype),  # output indices
+        Tensor(0, constexpr=True, value=loop_k),  # loop bound
+    )
+
+    return arrangement_, application, tensors
\ No newline at end of file
diff --git a/src/ntops/kernels/stack.py b/src/ntops/kernels/stack.py
new file mode 100644
index 0000000..60676c4
--- /dev/null
+++ b/src/ntops/kernels/stack.py
@@ -0,0 +1,46 @@
+import functools
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+
+def arrangement(input, output, block_size):
+    if block_size is None:
+        block_size = ninetoothed.block_size()
+
+    # input: a single slice of arbitrary dimensionality
+    # output: the corresponding output slice
+
+    # 1. Flatten.
+    # Stack/copy is element-wise and does not depend on the spatial structure,
+    # so the data can be treated as one contiguous 1-D buffer, which maximizes
+    # memory bandwidth.
+    input_arranged = input.flatten()
+    output_arranged = output.flatten()
+
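+    # Illustrative example (hypothetical sizes, not from the original code): a (4, 8, 16)
+    # slice flattens to 512 elements and, with block_size=128, tiles into 4 blocks of 128.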
+    # 2. Tile.
+    # Split the 1-D data into blocks of size block_size.
+    # Shape change: (total_elements,) -> (num_blocks, block_size)
+    input_arranged = input_arranged.tile((block_size,))
+    output_arranged = output_arranged.tile((block_size,))
+
+    return input_arranged, output_arranged
+
+
+def application(input, output):
+    # input: (block_size,)
+    # output: (block_size,)
+
+    # Plain element-wise assignment.
+    # The DSL translates this into load(input) -> store(output).
+    output = input
+
+
+def premake(ndim, block_size=None, dtype=None):
+    arrangement_ = functools.partial(
+        arrangement,
+        block_size=block_size,
+    )
+
+    tensors = (
+        Tensor(ndim, dtype=dtype),  # input (source tensor)
+        Tensor(ndim, dtype=dtype),  # output (destination slice)
+    )
+
+    return arrangement_, application, tensors
\ No newline at end of file
diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py
index 702877e..4361130 100644
--- a/src/ntops/torch/__init__.py
+++ b/src/ntops/torch/__init__.py
@@ -36,6 +36,11 @@
 from ntops.torch.softmax import softmax
 from ntops.torch.sub import sub
 from ntops.torch.tanh import tanh
+from ntops.torch.max_pool1d import max_pool1d
+from ntops.torch.max_pool3d import max_pool3d
+from ntops.torch.stack import stack
+from ntops.torch.mean import mean
+from ntops.torch.median import median
 
 __all__ = [
     "abs",
@@ -75,5 +80,10 @@
     "sin",
     "softmax",
     "sub",
-    "tanh",
+    "tanh",
+    "max_pool1d",
+    "max_pool3d",
+    "stack",
+    "mean",
+    "median",
 ]
diff --git a/src/ntops/torch/max_pool1d.py b/src/ntops/torch/max_pool1d.py
new file mode 100644
index 0000000..9a4d438
--- /dev/null
+++ b/src/ntops/torch/max_pool1d.py
@@ -0,0 +1,61 @@
+import math
+
+import torch
+import torch.nn.functional as F
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def max_pool1d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False):
+    # Dimensionality check.
+    is_3d = input.ndim == 3
+    if input.ndim == 2:
+        input = input.unsqueeze(0)  # (C, L) -> (1, C, L)
+
+    assert input.ndim == 3, "Input tensor must be 3-dimensional (N, C, L) or 2-dimensional (C, L)"
+
+    # Normalize the parameters.
+    if stride is None:
+        stride = kernel_size
+
+    # Accept tuple parameters (1-D pooling usually takes ints, but torch allows (int,)).
+    if isinstance(kernel_size, tuple):
+        kernel_size = kernel_size[0]
+    if isinstance(stride, tuple):
+        stride = stride[0]
+    if isinstance(padding, tuple):
+        padding = padding[0]
+    if isinstance(dilation, tuple):
+        dilation = dilation[0]
+
+    assert dilation == 1, "Currently only dilation=1 is supported in this DSL implementation"
+
+    # Explicit padding: if padding is requested, pad the input with -inf first.
+    if padding > 0:
+        input = F.pad(input, (padding, padding), value=float("-inf"))
+
+    L_in = input.shape[-1]
+
+    # Compute the output length.
+    if ceil_mode:
+        L_out = math.ceil((L_in - kernel_size + stride) / stride)
+    else:
+        L_out = math.floor((L_in - kernel_size + stride) / stride)
+
+    # Allocate the output.
+    output_shape = (input.shape[0], input.shape[1], L_out)
+    output = torch.empty(output_shape, dtype=input.dtype, device=input.device)
+
+    block_size = 1024
+
+    kernel = _cached_make(
+        ntops.kernels.max_pool1d.premake,
+        input.ndim,
+        kernel_size,
+        stride,
+        block_size=block_size,
+        ceil_mode=ceil_mode,
+        dtype=input.dtype,
+    )
+
+    kernel(input, output)
+
+    if not is_3d:
+        output = output.squeeze(0)
+
+    return output
\ No newline at end of file
diff --git a/src/ntops/torch/max_pool3d.py b/src/ntops/torch/max_pool3d.py
new file mode 100644
index 0000000..1bf4931
--- /dev/null
+++ b/src/ntops/torch/max_pool3d.py
@@ -0,0 +1,85 @@
+import math
+
+import torch
+import torch.nn.functional as F
+
+import ntops
+from ntops.torch.utils import _cached_make
+
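+# Worked example (illustrative values, not from the original code): for in_dim=5, k=2,
+# s=2, p=0, ceil_mode=True, the output-size computation below gives
+# out_dim = ceil((5 - 2) / 2) + 1 = 3; the last window starts at index 4 < 5, so it is
+# kept, and one column of right padding ((3 - 1) * 2 + 2 - 5 = 1) is added so that the
+# floor tiling sees exactly 3 full windows.
+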
+def max_pool3d(input, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False):
+    assert input.ndim == 5, "Input tensor must be 5-dimensional (N, C, D, H, W)"
+    assert dilation == 1 or dilation == (1, 1, 1), "Currently only dilation=1 is supported"
+
+    # --- Normalize the parameters ---
+    def _triple(x):
+        if isinstance(x, int):
+            return (x, x, x)
+        if len(x) == 1:
+            return (x[0], x[0], x[0])
+        return x
+
+    k_d, k_h, k_w = _triple(kernel_size)
+
+    if stride is None:
+        s_d, s_h, s_w = k_d, k_h, k_w
+    else:
+        s_d, s_h, s_w = _triple(stride)
+
+    pad_d, pad_h, pad_w = _triple(padding)
+
+    # --- Compute the exact output size and the right-side padding it requires ---
+    def _calc_dim_and_pad(in_dim, k, s, p, ceil_mode):
+        # 1. The standard PyTorch output-size formula.
+        if ceil_mode:
+            out_dim = math.ceil((in_dim + 2 * p - k) / s) + 1
+        else:
+            out_dim = math.floor((in_dim + 2 * p - k) / s) + 1
+
+        # 2. PyTorch's boundary check (in ceil mode, windows that start entirely in the
+        # padding are excluded).
+        # Rule: if the last window starts at or beyond original length + left padding,
+        # that window is invalid.
+        if ceil_mode:
+            if (out_dim - 1) * s >= in_dim + p:
+                out_dim -= 1
+
+        # 3. Work backwards to the total length required so the DSL's floor tiling works:
+        # we need (out_dim - 1) * s + k elements.
+        needed_len = (out_dim - 1) * s + k
+
+        # 4. Compute how much to pad on the right.
+        # Currently available: in_dim + p (left padding).
+        current_len = in_dim + p
+        pad_right = max(0, needed_len - current_len)
+
+        return out_dim, pad_right
+
+    D_in, H_in, W_in = input.shape[-3], input.shape[-2], input.shape[-1]
+
+    D_out, pad_r_d = _calc_dim_and_pad(D_in, k_d, s_d, pad_d, ceil_mode)
+    H_out, pad_r_h = _calc_dim_and_pad(H_in, k_h, s_h, pad_h, ceil_mode)
+    W_out, pad_r_w = _calc_dim_and_pad(W_in, k_w, s_w, pad_w, ceil_mode)
+
+    # --- Apply explicit padding ---
+    # F.pad order: (W_left, W_right, H_top, H_bottom, D_front, D_back)
+    if any(p > 0 for p in [pad_w, pad_r_w, pad_h, pad_r_h, pad_d, pad_r_d]):
+        input = F.pad(
+            input,
+            (pad_w, pad_r_w, pad_h, pad_r_h, pad_d, pad_r_d),
+            value=float("-inf"),
+        )
+
+    output_shape = (input.shape[0], input.shape[1], D_out, H_out, W_out)
+    output = torch.empty(output_shape, dtype=input.dtype, device=input.device)
+
+    block_size = 1024
+
+    # Note: ceil_mode is always passed as False here, because the explicit padding
+    # above already guarantees that floor tiling produces the correct output size.
+    kernel = _cached_make(
+        ntops.kernels.max_pool3d.premake,
+        input.ndim,
+        k_d, k_h, k_w,
+        s_d, s_h, s_w,
+        block_size=block_size,
+        ceil_mode=False,
+        dtype=input.dtype,
+    )
+
+    kernel(input, output)
+
+    return output
\ No newline at end of file
diff --git a/src/ntops/torch/mean.py b/src/ntops/torch/mean.py
new file mode 100644
index 0000000..a666636
--- /dev/null
+++ b/src/ntops/torch/mean.py
@@ -0,0 +1,123 @@
+import math
+
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+# The kernels defined in ntops/kernels/mean.py.
+import ntops.kernels.mean
+
+
+def next_power_of_2(n):
+    if n == 0:
+        return 1
+    return 1 << (n - 1).bit_length()
+
+
+def get_optimal_block_size(dim_size):
+    target_size = next_power_of_2(dim_size)
+    if target_size > 1024:
+        target_size = 1024
+    if target_size < 32:
+        target_size = 32
+    return target_size
+
+
+def mean(
+    input,
+    dim: int | tuple[int] | list[int] | None = None,
+    keepdim=False,
+    *,
+    dtype=None,
+    out=None,
+):
+    # 1. Determine the computation dtype.
+    # If the input is an integer type, mean must promote it to floating point.
+    if dtype is None:
+        if input.dtype.is_floating_point:
+            computation_dtype = input.dtype
+        else:
+            computation_dtype = torch.float32
+    else:
+        computation_dtype = dtype
+
+    # 2. Count the number of elements N used for the division.
+    if dim is None:
+        num_elements = input.numel()
+    else:
+        if isinstance(dim, int):
+            dims = (dim,)
+        else:
+            dims = tuple(dim)
+
+        num_elements = 1
+        for d in dims:
+            num_elements *= input.shape[d]
+
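+    # Illustrative example (hypothetical shapes, not from the original code): for an
+    # input of shape (4, 8, 16) and dim=(1,), num_elements == 8, so the kernel sums
+    # over dimension 1 and the result is divided by 8 afterwards.
+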
+    # 3. Kernel computation (essentially a sum, carried out in computation_dtype).
+
+    # --- Case A: global mean (all elements) ---
+    if dim is None:
+        current = input
+        block_size = get_optimal_block_size(current.numel())
+
+        # Recursive reduction (global reduction).
+        while current.numel() > 1:
+            output_len = math.ceil(current.numel() / block_size)
+            output = torch.empty((output_len,), dtype=computation_dtype, device=current.device)
+
+            kernel = _cached_make(
+                ntops.kernels.mean.premake_all_elements,
+                current.ndim,
+                computation_dtype,  # make sure the kernel works in floating point
+                block_size,
+            )
+            kernel(current, output)
+            current = output
+
+        result_sum = current.view(())
+
+        # 4. Perform the division.
+        result = result_sum.div(num_elements)
+
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
+
+    # --- Case B: mean over the given dimension(s) ---
+    else:
+        output_shape = list(input.shape)
+        for d in dims:
+            if d < 0:
+                d += input.ndim
+            output_shape[d] = 1
+
+        temp_out = torch.empty(output_shape, dtype=computation_dtype, device=input.device)
+        block_size = get_optimal_block_size(output_shape[dims[0]])
+
+        kernel = _cached_make(
+            ntops.kernels.mean.premake,
+            input.ndim,
+            dims,
+            computation_dtype,  # make sure the kernel works in floating point
+            block_size,
+        )
+        kernel(input, temp_out)
+
+        # 4. Perform the division (in place, to avoid an extra allocation).
+        temp_out.div_(num_elements)
+
+        if not keepdim:
+            dims_to_remove = sorted(
+                [d if d >= 0 else d + input.ndim for d in dims], reverse=True
+            )
+            final_shape = list(output_shape)
+            for d in dims_to_remove:
+                del final_shape[d]
+
+            if not final_shape:
+                temp_out = temp_out.view(())
+            else:
+                temp_out = temp_out.view(final_shape)
+
+        if out is not None:
+            out.copy_(temp_out)
+            return out
+
+        return temp_out
\ No newline at end of file
diff --git a/src/ntops/torch/median.py b/src/ntops/torch/median.py
new file mode 100644
index 0000000..2f16eb6
--- /dev/null
+++ b/src/ntops/torch/median.py
@@ -0,0 +1,83 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def next_power_of_2(n):
+    if n == 0:
+        return 1
+    return 1 << (n - 1).bit_length()
+
+
+def get_optimal_block_size(dim_size):
+    target_size = next_power_of_2(dim_size)
+    if target_size > 1024:
+        target_size = 1024
+    if target_size < 32:
+        target_size = 32
+    return target_size
+
+
+def median(input, dim=-1, keepdim=False, *, out=None):
+    """
+    Args:
+        input: input tensor
+        dim: dimension along which the median is computed
+        keepdim: whether to keep the reduced dimension in the output
+        out: optional (values, indices) tuple
+    Returns:
+        (values, indices) namedtuple or tuple
+    """
+    dtype = input.dtype
+    indices_dtype = torch.int64
+
+    # Handle dim.
+    if dim is None:
+        raise NotImplementedError("median currently requires a specific dim")
+
+    if dim < 0:
+        dim += input.ndim
+
+    target_dim = dim
+    dim_size = input.shape[target_dim]
+
+    # Rank of the median within the sorted dimension (PyTorch's default: (N - 1) // 2).
+    # For example, N=3 -> idx=1; N=4 -> idx=1.
+    median_rank = (dim_size - 1) // 2
+
+    input_logic = input
+    block_size = get_optimal_block_size(dim_size)
+
+    # Prepare the outputs.
+    # Median reduces target_dim away (or keeps it when keepdim=True), but the kernel's
+    # arrangement maps input blocks to output blocks one-to-one, and a reduction kernel's
+    # output typically has size 1 along target_dim (or is squeezed).
+    # The buffers are therefore allocated with the keepdim=True shape and squeezed at the
+    # end depending on the argument.
+    output_shape = list(input.shape)
+    output_shape[target_dim] = 1  # a single result element along target_dim
+
+    if out is not None:
+        values_logic, indices_logic = out
+    else:
+        values_logic = torch.empty(output_shape, dtype=dtype, device=input.device)
+        indices_logic = torch.empty(output_shape, dtype=indices_dtype, device=input.device)
+
+    # Build the kernel.
+    # Note: median_rank is passed to premake as loop_k.
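+    # Illustrative example (hypothetical size, not from the original code): for
+    # dim_size=5, median_rank = (5 - 1) // 2 = 2, so the kernel's selection loop discards
+    # the two smallest values along the dimension and returns the third smallest as the
+    # median.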
+    kernel = _cached_make(
+        ntops.kernels.median.premake,  # the kernel defined in ntops/kernels/median.py
+        input_logic.ndim,
+        target_dim,
+        median_rank,  # loop_k
+        dtype,
+        indices_dtype,
+        block_size,
+    )
+
+    # Launch the kernel.
+    kernel(input_logic, values_logic, indices_logic, median_rank)
+
+    # Handle keepdim.
+    if not keepdim:
+        values_logic = values_logic.squeeze(dim)
+        indices_logic = indices_logic.squeeze(dim)
+
+    return torch.return_types.median((values_logic, indices_logic))
\ No newline at end of file
diff --git a/src/ntops/torch/stack.py b/src/ntops/torch/stack.py
new file mode 100644
index 0000000..76ae6a7
--- /dev/null
+++ b/src/ntops/torch/stack.py
@@ -0,0 +1,48 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def stack(tensors, dim=0):
+    if not tensors:
+        raise ValueError("stack expects a non-empty list of tensors")
+
+    first_tensor = tensors[0]
+    target_dtype = first_tensor.dtype
+    target_device = first_tensor.device
+    input_shape = first_tensor.shape
+    ndim = first_tensor.ndim
+
+    # 1. All input tensors must have the same shape, dtype, and device.
+    for t in tensors:
+        assert t.shape == input_shape, f"Shape mismatch: expected {input_shape}, got {t.shape}"
+        assert t.dtype == target_dtype, "Dtype mismatch in input tensors"
+        assert t.device == target_device, "Device mismatch in input tensors"
+
+    # 2. Compute the output shape.
+    # stack inserts a new dimension of size len(tensors) at position dim.
+    output_shape = list(input_shape)
+    output_shape.insert(dim, len(tensors))
+
+    # 3. Allocate the output.
+    output = torch.empty(output_shape, dtype=target_dtype, device=target_device)
+
+    # 4. Prepare the kernel.
+    # Note: the input's ndim is used here, because the kernel operates on slices.
+    block_size = 1024
+    kernel = _cached_make(
+        ntops.kernels.stack.premake,  # the kernel defined in ntops/kernels/stack.py
+        ndim,
+        block_size=block_size,
+        dtype=target_dtype,
+    )
+
+    # 5. Iterative execution: copy each input tensor into its slice of the output.
+    for i, t in enumerate(tensors):
+        # output.select(dim, i) returns a view that shares the output's memory but has
+        # the same shape as the input. The kernel treats the pair like any two ordinary
+        # tensors.
+        output_slice = output.select(dim, i)
+        kernel(t, output_slice)
+
+    return output
\ No newline at end of file
diff --git a/tests/test_max_pool1d.py b/tests/test_max_pool1d.py
new file mode 100644
index 0000000..bb16487
--- /dev/null
+++ b/tests/test_max_pool1d.py
@@ -0,0 +1,52 @@
+import random
+
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("ceil_mode", [False, True])
+@pytest.mark.parametrize("padding", [0, 1, 2])
+@pytest.mark.parametrize("kernel_stride", [(2, 2), (3, 2), (3, 1)])  # (kernel, stride)
+def test_max_pool1d(ceil_mode, padding, kernel_stride):
+    device = "cuda"
+    dtype = torch.float32
+
+    kernel_size, stride = kernel_stride
+
+    batch = random.randint(1, 4)
+    channels = random.randint(1, 4)
+    length = random.randint(10, 50)
+
+    input_tensor = torch.randn((batch, channels, length), device=device, dtype=dtype)
+
+    # ntops implementation.
+    try:
+        ntops_output = ntops.torch.max_pool1d(
+            input_tensor,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            ceil_mode=ceil_mode,
+        )
+    except Exception as e:
+        # Some extreme size combinations can make native torch raise; if the failure
+        # comes from our operator instead, it surfaces here and fails the test.
+        pytest.fail(f"Kernel execution failed: {e}")
+
+    # Reference implementation.
+    try:
+        reference_output = torch.nn.functional.max_pool1d(
+            input_tensor,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            ceil_mode=ceil_mode,
+        )
+    except RuntimeError:
+        # If torch itself rejects the configuration (e.g. the padding is too large and
+        # L_out would be < 0), the comparison is skipped; ideally our implementation
+        # should behave the same way.
+        return
+
+    assert torch.allclose(ntops_output, reference_output, atol=1e-3, rtol=1e-3)
\ No newline at end of file
diff --git a/tests/test_max_pool3d.py b/tests/test_max_pool3d.py
new file mode 100644
index 0000000..204776f
--- /dev/null
+++ b/tests/test_max_pool3d.py
@@ -0,0 +1,48 @@
+import random
+
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("ceil_mode", [False, True])
+@pytest.mark.parametrize("padding", [0, 1])
+@pytest.mark.parametrize("use_tuple", [False, True])
+def test_max_pool3d(ceil_mode, padding, use_tuple):
+    device = "cuda"
+    dtype = torch.float32
+
+    batch = random.randint(1, 2)
+    channels = random.randint(1, 3)
+    depth = random.randint(8, 16)
+    height = random.randint(8, 16)
+    width = random.randint(8, 16)
+
+    if use_tuple:
+        kernel_size = (random.randint(2, 3), random.randint(2, 3), random.randint(2, 3))
+        stride = (random.randint(1, 2), random.randint(1, 2), random.randint(1, 2))
+    else:
+        kernel_size = random.randint(2, 3)
+        stride = random.randint(1, 2)
+
+    input_tensor = torch.randn((batch, channels, depth, height, width), device=device, dtype=dtype)
+
+    # ntops implementation.
+    ntops_output = ntops.torch.max_pool3d(
+        input_tensor,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        ceil_mode=ceil_mode,
+    )
+
+    # Reference implementation.
+    reference_output = torch.nn.functional.max_pool3d(
+        input_tensor,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        ceil_mode=ceil_mode,
+    )
+
+    assert torch.allclose(ntops_output, reference_output, atol=1e-3, rtol=1e-3)
\ No newline at end of file
diff --git a/tests/test_mean.py b/tests/test_mean.py
new file mode 100644
index 0000000..26ef10f
--- /dev/null
+++ b/tests/test_mean.py
@@ -0,0 +1,56 @@
+import random
+
+import pytest
+import torch
+
+import ntops
+# mean is assumed to be exported from the ntops package as well.
+from ntops.torch.mean import mean as ntops_mean
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+@pytest.mark.parametrize("keepdim", (False, True))
+def test_mean_dim(shape, dtype, device, rtol, atol, keepdim):
+    # Mean tests usually need a slightly looser tolerance for float16, because the
+    # division introduces extra error.
+    if dtype == torch.float16:
+        atol = max(atol, 1e-3)
+        rtol = max(rtol, 1e-3)
+
+    input_tensor = torch.randn(shape, dtype=dtype, device=device)
+    dim = random.randint(0, input_tensor.ndim - 1)
+
+    if random.choice((True, False)):
+        dim = dim - input_tensor.ndim
+
+    ntops_value = ntops_mean(input_tensor, dim=dim, keepdim=keepdim)
+    reference_value = torch.mean(input_tensor, dim=dim, keepdim=keepdim)
+
+    assert torch.allclose(ntops_value, reference_value, rtol=rtol, atol=atol)
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_arguments())
+def test_mean_global(shape, dtype, device, rtol, atol):
+    if dtype == torch.float16:
+        atol = max(atol, 1e-3)
+        rtol = max(rtol, 1e-3)
+
+    input_tensor = torch.randn(shape, dtype=dtype, device=device)
+
+    ntops_value = ntops_mean(input_tensor)
+    reference_value = torch.mean(input_tensor)
+
+    assert torch.allclose(ntops_value, reference_value, rtol=rtol, atol=atol)
+
+
+@skip_if_cuda_not_available
+def test_mean_int_input():
+    # Dedicated test: integer input must produce a floating-point output.
+    device = "cuda"
+    shape = (1024, 1024)
+    input_tensor = torch.randint(0, 10, shape, device=device, dtype=torch.int32)
+
+    ntops_value = ntops_mean(input_tensor)
+    # torch.mean on integer tensors may error or behave differently across versions,
+    # so the reference converts to float explicitly.
+    reference_value = torch.mean(input_tensor.float())
+
+    assert ntops_value.is_floating_point()
+    assert torch.allclose(ntops_value, reference_value)
\ No newline at end of file
diff --git a/tests/test_median.py b/tests/test_median.py
new file mode 100644
index 0000000..93bc0d7
--- /dev/null
+++ b/tests/test_median.py
@@ -0,0 +1,51 @@
+import random
+
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import _random_shape
+
+
+def generate_median_args():
+    args = []
+    # Median is also a selection-style reduction with relatively high cost, so the
+    # tests keep dim_size small.
+    for dtype in (torch.float32, torch.float16):
+        device = "cuda"
+        rtol, atol = (1e-3, 1e-3) if dtype == torch.float32 else (1e-2, 1e-2)
+
+        for ndim in range(1, 4):
+            for _ in range(5):
+                shape = _random_shape(ndim)
+                dim = random.randint(0, ndim - 1)
+
+                # Cap dim_size, for the same reason as in the argsort tests.
+                if shape[dim] > 128:
+                    shape_list = list(shape)
+                    shape_list[dim] = random.randint(10, 128)
+                    shape = tuple(shape_list)
+
+                args.append((shape, dim, dtype, device, rtol, atol))
+    return "shape, dim, dtype, device, rtol, atol", args
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize(*generate_median_args())
+def test_median(shape, dim, dtype, device, rtol, atol):
+    input_tensor = torch.randn(shape, dtype=dtype, device=device)
+
+    # 1. ntops implementation.
+    ntops_vals, ntops_idxs = ntops.torch.median(input_tensor, dim=dim, keepdim=False)
+
+    # 2. PyTorch reference implementation.
+    ref_vals, ref_idxs = torch.median(input_tensor, dim=dim, keepdim=False)
+
+    # 3. Check the values (floating-point comparison needs a tolerance).
+    assert torch.allclose(ntops_vals, ref_vals, rtol=rtol, atol=atol), "Values mismatch"
+
+    # 4. Check the indices.
+    # Note: with duplicate values the median index is not unique. The more robust check
+    # is to gather from the input with the ntops indices and compare the gathered values
+    # against the reference values.
+    gathered_vals = torch.gather(input_tensor, dim, ntops_idxs.unsqueeze(dim)).squeeze(dim)
+    assert torch.allclose(gathered_vals, ref_vals, rtol=rtol, atol=atol), "Gathered values from indices mismatch"
+
+    # If the data has no duplicates, the indices could also be compared directly (optional):
+    # assert torch.equal(ntops_idxs, ref_idxs)
\ No newline at end of file
diff --git a/tests/test_stack.py b/tests/test_stack.py
new file mode 100644
index 0000000..592cc71
--- /dev/null
+++ b/tests/test_stack.py
@@ -0,0 +1,37 @@
+import random
+
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("dim", [0, 1, 2])
+def test_stack(dim):
+    device = "cuda"
+    dtype = torch.float32
+
+    # Construct the test parameters.
+    num_tensors = random.randint(2, 5)
+    shape = (random.randint(16, 64), random.randint(16, 64), random.randint(16, 64))
+
+    # Make sure dim is in range (ndim + 1 positions are valid, because stack adds a dimension).
+    if dim > len(shape):
+        return
+
+    # Create the list of inputs.
+    tensors = [
+        torch.randn(shape, device=device, dtype=dtype)
+        for _ in range(num_tensors)
+    ]
+
+    # Reference implementation.
+    reference_output = torch.stack(tensors, dim=dim)
+
+    # ntops implementation.
+    ntops_output = ntops.torch.stack(tensors, dim=dim)
+
+    # Check that the results match.
+    assert torch.allclose(ntops_output, reference_output, atol=1e-6, rtol=1e-6)
+
+    # Stack's output should be contiguous.
+    assert ntops_output.is_contiguous()
\ No newline at end of file