Skip to content
This repository was archived by the owner on May 27, 2021. It is now read-only.

Commit 340e791

Browse files
Add translate variant for offset
1 parent 8ce4dab commit 340e791

File tree

3 files changed

+37
-16
lines changed

3 files changed

+37
-16
lines changed

src/device/matmul_kernels/kernel.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ function matmul_impl(a, b, c, d,
4545

4646
@unroll for i = 1 : NUM_FRAGMENTS_M
4747
@unroll for j = 1 : NUM_FRAGMENTS_N
48-
tile = translate(warp_tile, (M = (i-1)*COMPUTE_OP_SHAPE.M, N = (j-1)*COMPUTE_OP_SHAPE.N))
48+
tile = translate_offset(warp_tile, (M = (i-1)*COMPUTE_OP_SHAPE.M, N = (j-1)*COMPUTE_OP_SHAPE.N))
4949
@inbounds c_frags[i, j] = transf_sh2rf_c(Operator.load_c(OPERATOR, SHARED_C_LAYOUT, shmem_c, tile), tile)
5050
end
5151
end
@@ -84,15 +84,15 @@ function matmul_impl(a, b, c, d,
8484
a_frags = MArray{Tuple{NUM_FRAGMENTS_M}, Operator.fragtype_a(OPERATOR, SHARED_A_LAYOUT)}(undef)
8585

8686
@unroll for i = 1 : NUM_FRAGMENTS_M
87-
a_tile = translate(warp_tile.MK, (M = (i-1)*COMPUTE_OP_SHAPE.M, K = 0))
87+
a_tile = translate_offset(warp_tile.MK, (M = (i-1)*COMPUTE_OP_SHAPE.M, K = 0))
8888
@inbounds a_frags[i] = transf_sh2rf_a(Operator.load_a(OPERATOR, SHARED_A_LAYOUT, shmem_a, a_tile), a_tile)
8989
end
9090

9191
# (3.3.2) Load a COMPUTE_WARP.K x COMPUTE_WARP.N tile of B from shared memory into registers
9292
b_frags = MArray{Tuple{NUM_FRAGMENTS_N}, Operator.fragtype_b(OPERATOR, SHARED_B_LAYOUT)}(undef)
9393

9494
@unroll for j = 1 : NUM_FRAGMENTS_N
95-
b_tile = translate(warp_tile.KN, (K = 0, N = (j-1)*COMPUTE_OP_SHAPE.N))
95+
b_tile = translate_offset(warp_tile.KN, (K = 0, N = (j-1)*COMPUTE_OP_SHAPE.N))
9696
@inbounds b_frags[j] = transf_sh2rf_b(Operator.load_b(OPERATOR, SHARED_B_LAYOUT, shmem_b, b_tile), b_tile)
9797
end
9898

@@ -114,7 +114,7 @@ function matmul_impl(a, b, c, d,
114114

115115
@unroll for i = 1 : NUM_FRAGMENTS_M
116116
@unroll for j = 1 : NUM_FRAGMENTS_N
117-
tile = translate(warp_tile, (M = (i-1)*COMPUTE_OP_SHAPE.M, N = (j-1)*COMPUTE_OP_SHAPE.N))
117+
tile = translate_offset(warp_tile, (M = (i-1)*COMPUTE_OP_SHAPE.M, N = (j-1)*COMPUTE_OP_SHAPE.N))
118118
Operator.store_d(OPERATOR, SHARED_D_LAYOUT, shmem_d, transf_rf2sh_d(c_frags[i, j], tile), tile)
119119
end
120120
end

src/device/matmul_kernels/operator.jl

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,31 +25,43 @@ struct WMMAOp{M, N, K} end
2525
@inline fragtype_b(::Type{WMMAOp{16, 16, 16}}, ::Type{Layout.AlignedColMajor{Float16}}) = WMMA.Fragment{16, 16, 16, 16, Float16, WMMA.ColMajor, WMMA.MatrixB}
2626
@inline fragtype_accum(::Type{WMMAOp{16, 16, 16}}, ::Type{Layout.AlignedColMajor{Float32}}) = WMMA.Fragment{16, 16, 16, 8, Float32, WMMA.Unspecified, WMMA.Accumulator}
2727

28-
function load_a(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float16}}, workspace, tile::Tile) where {M, N, K}
28+
@inline function load_a(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float16}}, workspace, tile::Tile) where {M, N, K}
2929
conf = WMMA.Config{M, N, K, Float32}
30-
linear_index = linearise(tile.index, size(workspace))
31-
ptr = pointer(workspace, linear_index)
30+
31+
linear_base = linearise(tile.base, size(workspace))
32+
linear_offset = linearise(tile.offset, size(workspace))
33+
34+
ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(Float16)
3235
return WMMA.load_a(ptr, size(workspace, 1), WMMA.ColMajor, conf)
3336
end
3437

35-
function load_b(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float16}}, workspace, tile::Tile) where {M, N, K}
38+
@inline function load_b(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float16}}, workspace, tile::Tile) where {M, N, K}
3639
conf = WMMA.Config{M, N, K, Float32}
37-
linear_index = linearise(tile.index, size(workspace))
38-
ptr = pointer(workspace, linear_index)
40+
41+
linear_base = linearise(tile.base, size(workspace))
42+
linear_offset = linearise(tile.offset, size(workspace))
43+
44+
ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(Float16)
3945
return WMMA.load_b(ptr, size(workspace, 1), WMMA.ColMajor, conf)
4046
end
4147

42-
function load_c(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float32}}, workspace, tile::Tile) where {M, N, K}
48+
@inline function load_c(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float32}}, workspace, tile::Tile) where {M, N, K}
4349
conf = WMMA.Config{M, N, K, Float32}
44-
linear_index = linearise(tile.index, size(workspace))
45-
ptr = pointer(workspace, linear_index)
50+
51+
linear_base = linearise(tile.base, size(workspace))
52+
linear_offset = linearise(tile.offset, size(workspace))
53+
54+
ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(Float32)
4655
return WMMA.load_c(ptr, size(workspace, 1), WMMA.ColMajor, conf)
4756
end
4857

49-
function store_d(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float32}}, workspace, frag, tile::Tile) where {M, N, K}
58+
@inline function store_d(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float32}}, workspace, frag, tile::Tile) where {M, N, K}
5059
conf = WMMA.Config{M, N, K, Float32}
51-
linear_index = linearise(tile.index, size(workspace))
52-
ptr = pointer(workspace, linear_index)
60+
61+
linear_base = linearise(tile.base, size(workspace))
62+
linear_offset = linearise(tile.offset, size(workspace))
63+
64+
ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(Float32)
5365
WMMA.store_d(ptr, frag, size(workspace, 1), WMMA.ColMajor, conf)
5466
end
5567

src/device/tiling.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,15 @@ end
132132

133133
@inline translate(tile::Tile{size, names, T}, offset::Tuple) where {names, T, size} = translate(tile, NamedTuple{names}(offset))
134134

135+
export translate_offset
136+
137+
@inline function translate_offset(tile::Tile{size, names, T}, offset::NamedTuple{names, T}) where {names, T, size}
138+
new_offset = map(+, tile.offset, offset)
139+
return Tile{size, names, T}(tile.base, new_offset)
140+
end
141+
142+
@inline translate_offset(tile::Tile{size, names, T}, offset::Tuple) where {names, T, size} = translate_offset(tile, NamedTuple{names}(offset))
143+
135144
# -------------
136145
# TileIterators
137146
# -------------

0 commit comments

Comments
 (0)