@@ -25,31 +25,43 @@ struct WMMAOp{M, N, K} end
"""
    fragtype_b(::Type{WMMAOp{16, 16, 16}}, ::Type{Layout.AlignedColMajor{Float16}})

Return the `WMMA.Fragment` type used for the `WMMA.MatrixB` operand of a
`WMMAOp{16, 16, 16}` whose data is laid out as `Layout.AlignedColMajor{Float16}`.

The fragment is parameterised with `Float16` elements and `WMMA.ColMajor` layout.
"""
@inline function fragtype_b(
    ::Type{WMMAOp{16, 16, 16}},
    ::Type{Layout.AlignedColMajor{Float16}},
)
    # NOTE(review): the fourth type parameter (16) is presumably the number of
    # elements held per fragment -- confirm against the WMMA.Fragment definition.
    return WMMA.Fragment{16, 16, 16, 16, Float16, WMMA.ColMajor, WMMA.MatrixB}
end
"""
    fragtype_accum(::Type{WMMAOp{16, 16, 16}}, ::Type{Layout.AlignedColMajor{Float32}})

Return the `WMMA.Fragment` type used for the `WMMA.Accumulator` operand of a
`WMMAOp{16, 16, 16}` whose data is laid out as `Layout.AlignedColMajor{Float32}`.

The fragment is parameterised with `Float32` elements and `WMMA.Unspecified` layout.
"""
@inline function fragtype_accum(
    ::Type{WMMAOp{16, 16, 16}},
    ::Type{Layout.AlignedColMajor{Float32}},
)
    # NOTE(review): the fourth type parameter (8) is presumably the number of
    # elements held per fragment -- confirm against the WMMA.Fragment definition.
    return WMMA.Fragment{16, 16, 16, 8, Float32, WMMA.Unspecified, WMMA.Accumulator}
end
2727
"""
    load_a(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float16}}, workspace, tile::Tile)

Load the A operand for `tile` from `workspace` through `WMMA.load_a`.

The source address is the element of `workspace` at the linearised `tile.base`,
advanced by the linearised `tile.offset`. The leading dimension passed to WMMA is
`size(workspace, 1)`, the layout is `WMMA.ColMajor`, and the configuration is
`WMMA.Config{M, N, K, Float32}`.

Returns whatever `WMMA.load_a` returns.
"""
@inline function load_a(
    ::Type{WMMAOp{M, N, K}},
    ::Type{Layout.AlignedColMajor{Float16}},
    workspace,
    tile::Tile,
) where {M, N, K}
    dims = size(workspace)
    base_idx = linearise(tile.base, dims)
    offset_idx = linearise(tile.offset, dims)

    # NOTE(review): the `- 1` presumably compensates for `linearise` returning a
    # 1-based index, and the pointer `+` is assumed to advance in bytes -- confirm.
    byte_shift = (offset_idx - 1) * sizeof(Float16)
    src = pointer(workspace, base_idx) + byte_shift

    return WMMA.load_a(src, size(workspace, 1), WMMA.ColMajor, WMMA.Config{M, N, K, Float32})
end
3437
"""
    load_b(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float16}}, workspace, tile::Tile)

Load the B operand for `tile` from `workspace` through `WMMA.load_b`.

The source address is the element of `workspace` at the linearised `tile.base`,
advanced by the linearised `tile.offset`. The leading dimension passed to WMMA is
`size(workspace, 1)`, the layout is `WMMA.ColMajor`, and the configuration is
`WMMA.Config{M, N, K, Float32}`.

Returns whatever `WMMA.load_b` returns.
"""
@inline function load_b(
    ::Type{WMMAOp{M, N, K}},
    ::Type{Layout.AlignedColMajor{Float16}},
    workspace,
    tile::Tile,
) where {M, N, K}
    dims = size(workspace)
    base_idx = linearise(tile.base, dims)
    offset_idx = linearise(tile.offset, dims)

    # NOTE(review): the `- 1` presumably compensates for `linearise` returning a
    # 1-based index, and the pointer `+` is assumed to advance in bytes -- confirm.
    byte_shift = (offset_idx - 1) * sizeof(Float16)
    src = pointer(workspace, base_idx) + byte_shift

    return WMMA.load_b(src, size(workspace, 1), WMMA.ColMajor, WMMA.Config{M, N, K, Float32})
end
4147
"""
    load_c(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float32}}, workspace, tile::Tile)

Load the C operand for `tile` from `workspace` through `WMMA.load_c`.

The source address is the element of `workspace` at the linearised `tile.base`,
advanced by the linearised `tile.offset`. The leading dimension passed to WMMA is
`size(workspace, 1)`, the layout is `WMMA.ColMajor`, and the configuration is
`WMMA.Config{M, N, K, Float32}`.

Returns whatever `WMMA.load_c` returns.
"""
@inline function load_c(
    ::Type{WMMAOp{M, N, K}},
    ::Type{Layout.AlignedColMajor{Float32}},
    workspace,
    tile::Tile,
) where {M, N, K}
    dims = size(workspace)
    base_idx = linearise(tile.base, dims)
    offset_idx = linearise(tile.offset, dims)

    # NOTE(review): the `- 1` presumably compensates for `linearise` returning a
    # 1-based index, and the pointer `+` is assumed to advance in bytes -- confirm.
    byte_shift = (offset_idx - 1) * sizeof(Float32)
    src = pointer(workspace, base_idx) + byte_shift

    return WMMA.load_c(src, size(workspace, 1), WMMA.ColMajor, WMMA.Config{M, N, K, Float32})
end
4857
"""
    store_d(::Type{WMMAOp{M, N, K}}, ::Type{Layout.AlignedColMajor{Float32}}, workspace, frag, tile::Tile)

Store the fragment `frag` into `workspace` at the location described by `tile`,
through `WMMA.store_d`.

The destination address is the element of `workspace` at the linearised
`tile.base`, advanced by the linearised `tile.offset`. The leading dimension passed
to WMMA is `size(workspace, 1)`, the layout is `WMMA.ColMajor`, and the
configuration is `WMMA.Config{M, N, K, Float32}`.

Returns whatever `WMMA.store_d` returns.
"""
@inline function store_d(
    ::Type{WMMAOp{M, N, K}},
    ::Type{Layout.AlignedColMajor{Float32}},
    workspace,
    frag,
    tile::Tile,
) where {M, N, K}
    dims = size(workspace)
    base_idx = linearise(tile.base, dims)
    offset_idx = linearise(tile.offset, dims)

    # NOTE(review): the `- 1` presumably compensates for `linearise` returning a
    # 1-based index, and the pointer `+` is assumed to advance in bytes -- confirm.
    byte_shift = (offset_idx - 1) * sizeof(Float32)
    dst = pointer(workspace, base_idx) + byte_shift

    return WMMA.store_d(dst, frag, size(workspace, 1), WMMA.ColMajor, WMMA.Config{M, N, K, Float32})
end
5567
0 commit comments