10 changes: 10 additions & 0 deletions src/ntops/kernels/__init__.py
@@ -36,6 +36,11 @@
softmax,
sub,
tanh,
max_pool1d,
max_pool3d,
stack,
mean,
median,
)

__all__ = [
@@ -76,4 +81,9 @@
"softmax",
"sub",
"tanh",
"max_pool1d",
"max_pool3d",
"stack",
"mean",
"median",
]
73 changes: 73 additions & 0 deletions src/ntops/kernels/max_pool1d.py
@@ -0,0 +1,73 @@
import functools
import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor

def arrangement(input, output, kernel_size, stride, block_size, ceil_mode):
if block_size is None:
block_size = ninetoothed.block_size()

# input: (N, C, L_in)
# output: (N, C, L_out)

# 1. Window partitioning
# tile shape (1, 1, kernel_size) -> windows of length kernel_size along L
# strides (1, 1, stride) -> step of stride along L
# floor_mode=not ceil_mode decides whether a trailing partial window is dropped
input_arranged = input.tile(
(1, 1, kernel_size),
(1, 1, stride),
floor_mode=not ceil_mode
)
# => (N, C, L_out), dtype=(1, 1, k)

# 2. Flatten and rearrange
input_arranged = input_arranged.ravel()
# => (N, C, L_out, 1, 1, k)

input_arranged = input_arranged.flatten(end_dim=3).flatten(start_dim=1)
# => (N*C*L_out, k)

# 3. Pad the window dim to the next power of two (for the parallel reduction)
# The padding value comes from other=float("-inf") in premake
nearest_pow2 = 1 << (kernel_size - 1).bit_length()
input_arranged = input_arranged.tile((1, nearest_pow2))
# => (..., 1), dtype=(1, nearest_pow2)

input_arranged.dtype = input_arranged.dtype.squeeze(0)
input_arranged = input_arranged.tile((block_size, -1))
input_arranged.dtype = input_arranged.dtype.ravel().squeeze(1)
# => (..., 1), dtype=(block_size, nearest_pow2)

# 4. Arrange the output to match
output_arranged = output.tile((1, 1, 1)) # (N, C, L_out)
output_arranged = output_arranged.ravel()
output_arranged = output_arranged.flatten(end_dim=3).flatten(start_dim=1)
output_arranged = output_arranged.tile((block_size, -1))
output_arranged.dtype = output_arranged.dtype.squeeze(1)

return input_arranged, output_arranged

def application(input, output):
# input: (block_size, nearest_pow2)
# output: (block_size, )

# Take the max directly; the -inf padding never wins, so it does not affect the result
output = ntl.max(input, axis=1)

def premake(ndim, kernel_size, stride, block_size=None, ceil_mode=False, dtype=None):
arrangement_ = functools.partial(
arrangement,
kernel_size=kernel_size,
stride=stride,
block_size=block_size,
ceil_mode=ceil_mode,
)

tensors = (
# input: padded with -inf, the identity element for max
Tensor(ndim, dtype=dtype, other=float("-inf")),
Tensor(ndim, dtype=dtype), # output
)

return arrangement_, application, tensors
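
For context, here is a minimal sketch of how a host-side wrapper could drive this premake. It assumes `ninetoothed.make(arrangement, application, tensors)` compiles a premade kernel the way the existing `ntops.torch` wrappers do; `max_pool1d_wrapper` and its body are illustrative, not part of this PR:

```python
import math

import ninetoothed
import torch

import ntops.kernels.max_pool1d as max_pool1d_kernel


def max_pool1d_wrapper(input, kernel_size, stride=None, ceil_mode=False):
    # Hypothetical wrapper, shown only to illustrate the premake contract.
    stride = stride or kernel_size
    n, c, l_in = input.shape
    # Output length must agree with floor_mode=not ceil_mode in the tiling
    # (no padding or dilation, which is all this kernel supports).
    if ceil_mode:
        l_out = math.ceil((l_in - kernel_size) / stride) + 1
    else:
        l_out = (l_in - kernel_size) // stride + 1
    output = torch.empty(n, c, l_out, dtype=input.dtype, device=input.device)
    arrangement, application, tensors = max_pool1d_kernel.premake(
        input.ndim, kernel_size, stride, ceil_mode=ceil_mode
    )
    kernel = ninetoothed.make(arrangement, application, tensors)
    kernel(input, output)
    return output
```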
71 changes: 71 additions & 0 deletions src/ntops/kernels/max_pool3d.py
@@ -0,0 +1,71 @@
import functools
import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor

def arrangement(input, output, k_d, k_h, k_w, s_d, s_h, s_w, block_size, ceil_mode):
if block_size is None:
block_size = ninetoothed.block_size()

# input: (N, C, D_in, H_in, W_in)
# output: (N, C, D_out, H_out, W_out)

# 1. Window partitioning (now with a depth dimension)
input_arranged = input.tile(
(1, 1, k_d, k_h, k_w),
(1, 1, s_d, s_h, s_w),
floor_mode=not ceil_mode
)
# => (N, C, D_out, H_out, W_out), dtype=(1, 1, k_d, k_h, k_w)

# 2. Flatten and rearrange
input_arranged = input_arranged.ravel()
# => (N, C, D_out, H_out, W_out, 1, 1, k_d, k_h, k_w)

# Note: end_dim=5 here, because the five leading dims N, C, D_out, H_out, W_out are merged into a single batch dim
input_arranged = input_arranged.flatten(end_dim=5).flatten(start_dim=1)
# => (N*C*D_out*H_out*W_out, k_d*k_h*k_w)

# 3. Pad the window dim to the next power of two (for the parallel reduction)
# The padding value comes from other=float("-inf") in premake
nearest_pow2 = 1 << (k_d * k_h * k_w - 1).bit_length()
input_arranged = input_arranged.tile((1, nearest_pow2))
# => (..., 1), dtype=(1, nearest_pow2)

input_arranged.dtype = input_arranged.dtype.squeeze(0)
input_arranged = input_arranged.tile((block_size, -1))
input_arranged.dtype = input_arranged.dtype.ravel().squeeze(1)
# => (..., 1), dtype=(block_size, nearest_pow2)

# 4. Arrange the output to match
output_arranged = output.tile((1, 1, 1, 1, 1)) # (N, C, D, H, W)
output_arranged = output_arranged.ravel()
output_arranged = output_arranged.flatten(end_dim=5).flatten(start_dim=1)
output_arranged = output_arranged.tile((block_size, -1))
output_arranged.dtype = output_arranged.dtype.squeeze(1)

return input_arranged, output_arranged

def application(input, output):
# input: (block_size, nearest_pow2)
# output: (block_size, )

# Standard max-pooling reduction; the -inf padding never wins
output = ntl.max(input, axis=1)

def premake(ndim, k_d, k_h, k_w, s_d, s_h, s_w, block_size=None, ceil_mode=False, dtype=None):
arrangement_ = functools.partial(
arrangement,
k_d=k_d, k_h=k_h, k_w=k_w,
s_d=s_d, s_h=s_h, s_w=s_w,
block_size=block_size,
ceil_mode=ceil_mode,
)

tensors = (
# input: padded with -inf, the identity element for max
Tensor(ndim, dtype=dtype, other=float("-inf")),
Tensor(ndim, dtype=dtype), # output
)

return arrangement_, application, tensors
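
As in the 1-D case, the wrapper has to compute output extents that agree with `floor_mode=not ceil_mode`. A small illustrative helper (`pool_out_size` is a hypothetical name; this is the no-padding, no-dilation case this kernel supports):

```python
import math


def pool_out_size(in_size, kernel, stride, ceil_mode=False):
    # Output extent implied by tile(..., floor_mode=not ceil_mode):
    # floor mode drops a trailing partial window, ceil mode keeps it.
    if ceil_mode:
        return math.ceil((in_size - kernel) / stride) + 1
    return (in_size - kernel) // stride + 1


# e.g. d_out = pool_out_size(d_in, k_d, s_d, ceil_mode), and likewise
# for h_out (k_h, s_h) and w_out (k_w, s_w).
```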
45 changes: 45 additions & 0 deletions src/ntops/kernels/mean.py
@@ -0,0 +1,45 @@
import functools
import ninetoothed.language as ntl
from ninetoothed import Tensor
# Reuse the shared reduction arrangement; its blocking logic is generic across reductions.
from ntops.kernels.reduction import arrangement

def application(input, output):
# The first step of a mean is the sum.
# We accumulate inside the kernel and leave the division by the
# element count to the host side.
accumulator = 0.0
for i in range(input.shape[0]):
block_sum = ntl.sum(input[i], axis=0)
accumulator += block_sum
# Cast the accumulated sum to the output dtype (typically a float type)
output[0] = ntl.cast(accumulator, output.dtype.dtype)

def premake(ndim, dim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)

# The mean op requires a floating-point output.
# If the caller passes an integer dtype it must be promoted first;
# here we assume dtype is already the float type chosen by the Torch-side wrapper.
tensors = (
Tensor(ndim, dtype=dtype), # Input
Tensor(ndim, dtype=dtype), # Output
)
return arrangement_, application, tensors

# --- Global reduction over all elements ---

def arrangement_all_elements(input, output, block_size=None):
input = input.flatten().tile((block_size,))
output = output.tile((1,))
return input, output

def application_all_elements(input, output):
output[0] = ntl.sum(input, 0)

def premake_all_elements(ndim, dtype=None, block_size=None):
arrangement_ = functools.partial(arrangement_all_elements, block_size=block_size)
tensors = (
Tensor(ndim, dtype=dtype),
Tensor(1, dtype=dtype),
)
return arrangement_, application_all_elements, tensors
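
Because `application` stores only the sum, the Torch-side wrapper is expected to promote integer inputs and perform the final division itself. A rough sketch of that contract, where `mean_wrapper` and `run_mean_kernel` are hypothetical names standing in for the real wrapper and kernel launch:

```python
import torch


def mean_wrapper(input, dim):
    # Mean requires a floating-point result, so promote integer inputs
    # before handing the tensor to the kernel (see the premake comment).
    if not input.dtype.is_floating_point:
        input = input.to(torch.float32)
    summed = run_mean_kernel(input, dim)  # hypothetical launch helper
    # The kernel accumulated the sum; finish the mean on the host.
    return summed / input.shape[dim]
```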
56 changes: 56 additions & 0 deletions src/ntops/kernels/median.py
@@ -0,0 +1,56 @@
import functools
import triton.language as tl # for dtype constants such as tl.int32
import ninetoothed.language as ntl
from ninetoothed import Tensor
from ntops.kernels.reduction import arrangement

def application(input, values, indices, loop_k):
val_block = input[0]

# Negate the input so that repeated max-selection walks the elements in ascending order
working_val = -val_block
sentinel = float("-inf")

# Original element indices
idx_block = ntl.arange(0, val_block.shape[0])

# Scalar result buffers
final_val = ntl.zeros([], dtype=val_block.dtype)

# tl.int32 (a dtype constant), not the string "int32"
final_idx = ntl.zeros([], dtype=tl.int32)

# Selection loop: iteration i extracts the (i + 1)-th smallest element;
# the value found at i == loop_k is the median
for i in range(loop_k + 1):
current_max_val = ntl.max(working_val, axis=0)
current_max_idx = ntl.argmax(working_val, axis=0) # returns int32

if i == loop_k:
real_val = -current_max_val
final_val = ntl.cast(real_val, values.dtype.dtype)
final_idx = current_max_idx # already int32

# Mask out the element just selected so it cannot be picked again
mask_selected = idx_block == current_max_idx
updated_working_val = ntl.where(mask_selected, sentinel, working_val)
working_val = ntl.cast(updated_working_val, working_val.dtype)

# Write the results back
values[0] = final_val
# Cast int32 to the requested index dtype (e.g., int64)
indices[0] = ntl.cast(final_idx, indices.dtype.dtype)

def premake(
ndim, dim, loop_k, dtype=None, indices_dtype=None, block_size=None
):
arrangement_ = functools.partial(arrangement, dim=dim, block_size=block_size)
pad_val = float("inf") # negated to -inf in application, so padding is never selected

tensors = (
Tensor(ndim, dtype=dtype, other=pad_val), # Input
Tensor(ndim, dtype=dtype), # Output Values
Tensor(ndim, dtype=indices_dtype), # Output Indices
Tensor(0, constexpr=True, value=loop_k), # Loop bound
)

return arrangement_, application, tensors
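
One thing worth spelling out for reviewers: since iteration `i` selects the `(i + 1)`-th smallest element, the value recorded at `i == loop_k` has 0-based ascending rank `loop_k`. A wrapper matching `torch.median`, which returns the lower median, would therefore pass (illustrative snippet, with `dim` the reduced dimension):

```python
# Rank selection: the kernel keeps the element of 0-based ascending
# rank loop_k, so the lower median (torch.median's convention) is:
n = input.shape[dim]    # number of elements being reduced
loop_k = (n - 1) // 2   # n=5 -> rank 2; n=4 -> rank 1 (lower median)
```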
46 changes: 46 additions & 0 deletions src/ntops/kernels/stack.py
@@ -0,0 +1,46 @@
import functools
import ninetoothed
import ninetoothed.language as ntl
from ninetoothed import Tensor

def arrangement(input, output, block_size):
if block_size is None:
block_size = ninetoothed.block_size()

# input: a single source slice of arbitrary shape
# output: the corresponding destination slice

# 1. Flatten
# Stack/copy is element-wise and does not depend on spatial structure,
# so we can treat the data as contiguous 1-D memory for maximum bandwidth.
input_arranged = input.flatten()
output_arranged = output.flatten()

# 2. Tile
# Split the 1-D data into blocks of block_size
# Shape: (total_elements,) -> (num_blocks, block_size)
input_arranged = input_arranged.tile((block_size,))
output_arranged = output_arranged.tile((block_size,))

return input_arranged, output_arranged

def application(input, output):
# input: (block_size, )
# output: (block_size, )

# Simple element-wise assignment;
# the DSL lowers it to load(input) -> store(output)
output = input

def premake(ndim, block_size=None, dtype=None):
arrangement_ = functools.partial(
arrangement,
block_size=block_size,
)

tensors = (
Tensor(ndim, dtype=dtype), # input (source tensor)
Tensor(ndim, dtype=dtype), # output (destination slice)
)

return arrangement_, application, tensors
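
Since the kernel itself is just a block copy, the Torch-side `stack` presumably preallocates the output and launches one copy per input, targeting the matching output slice. An illustrative sketch: `copy_kernel` stands in for the compiled kernel, and it is an assumption that the kernel accepts the non-contiguous view produced by `select`:

```python
import torch


def stack_wrapper(tensors, dim=0):
    # Preallocate the stacked output with the new dimension inserted.
    out_shape = list(tensors[0].shape)
    out_shape.insert(dim, len(tensors))
    output = torch.empty(
        out_shape, dtype=tensors[0].dtype, device=tensors[0].device
    )
    for i, t in enumerate(tensors):
        # Each launch copies one input into its slice of the output.
        copy_kernel(t, output.select(dim, i))
    return output
```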
12 changes: 11 additions & 1 deletion src/ntops/torch/__init__.py
@@ -36,6 +36,11 @@
from ntops.torch.softmax import softmax
from ntops.torch.sub import sub
from ntops.torch.tanh import tanh
from ntops.torch.max_pool1d import max_pool1d
from ntops.torch.max_pool3d import max_pool3d
from ntops.torch.stack import stack
from ntops.torch.mean import mean
from ntops.torch.median import median

__all__ = [
"abs",
@@ -75,5 +80,10 @@
"sin",
"softmax",
"sub",
"tanh",
"tanh",
"max_pool1d",
"max_pool3d",
"stack",
"mean",
"median",
]