@@ -25,43 +25,31 @@ struct WMMAOp{M, N, K} end
# Fragment type for the B operand of a 16x16x16 WMMA operation whose source uses the
# `Layout.AlignedColMajor{Float16}` layout: a `Float16`, `WMMA.ColMajor`, `WMMA.MatrixB`
# fragment.
@inline function fragtype_b(
    ::Type{WMMAOp{16, 16, 16}},
    ::Type{Layout.AlignedColMajor{Float16}},
)
    return WMMA.Fragment{16, 16, 16, 16, Float16, WMMA.ColMajor, WMMA.MatrixB}
end
# Fragment type for the accumulator of a 16x16x16 WMMA operation whose source uses the
# `Layout.AlignedColMajor{Float32}` layout: a `Float32`, `WMMA.Unspecified`,
# `WMMA.Accumulator` fragment.
@inline function fragtype_accum(
    ::Type{WMMAOp{16, 16, 16}},
    ::Type{Layout.AlignedColMajor{Float32}},
)
    return WMMA.Fragment{16, 16, 16, 8, Float32, WMMA.Unspecified, WMMA.Accumulator}
end
2727
# Load the A-operand fragment addressed by `tile` from `workspace`.
#
# `tile.index` is flattened against `size(workspace)` with `linearise`; a pointer to that
# element is then handed to `WMMA.load_a` together with `size(workspace, 1)`,
# `WMMA.ColMajor` and a `WMMA.Config{M, N, K, Float32}`. The value returned by
# `WMMA.load_a` is returned unchanged.
function load_a(
    ::Type{WMMAOp{M, N, K}},
    ::Type{Layout.AlignedColMajor{Float16}},
    workspace,
    tile::Tile,
) where {M, N, K}
    dims = size(workspace)
    addr = pointer(workspace, linearise(tile.index, dims))
    config = WMMA.Config{M, N, K, Float32}
    return WMMA.load_a(addr, size(workspace, 1), WMMA.ColMajor, config)
end
3734
# Load the B-operand fragment addressed by `tile` from `workspace`.
#
# `tile.index` is flattened against `size(workspace)` with `linearise`; a pointer to that
# element is then handed to `WMMA.load_b` together with `size(workspace, 1)`,
# `WMMA.ColMajor` and a `WMMA.Config{M, N, K, Float32}`. The value returned by
# `WMMA.load_b` is returned unchanged.
function load_b(
    ::Type{WMMAOp{M, N, K}},
    ::Type{Layout.AlignedColMajor{Float16}},
    workspace,
    tile::Tile,
) where {M, N, K}
    dims = size(workspace)
    addr = pointer(workspace, linearise(tile.index, dims))
    config = WMMA.Config{M, N, K, Float32}
    return WMMA.load_b(addr, size(workspace, 1), WMMA.ColMajor, config)
end
4741
# Load the C (accumulator input) fragment addressed by `tile` from `workspace`.
#
# `tile.index` is flattened against `size(workspace)` with `linearise`; a pointer to that
# element is then handed to `WMMA.load_c` together with `size(workspace, 1)`,
# `WMMA.ColMajor` and a `WMMA.Config{M, N, K, Float32}`. The value returned by
# `WMMA.load_c` is returned unchanged.
function load_c(
    ::Type{WMMAOp{M, N, K}},
    ::Type{Layout.AlignedColMajor{Float32}},
    workspace,
    tile::Tile,
) where {M, N, K}
    dims = size(workspace)
    addr = pointer(workspace, linearise(tile.index, dims))
    config = WMMA.Config{M, N, K, Float32}
    return WMMA.load_c(addr, size(workspace, 1), WMMA.ColMajor, config)
end
5748
# Store the fragment `frag` into `workspace` at the position addressed by `tile`.
#
# `tile.index` is flattened against `size(workspace)` with `linearise`; a pointer to that
# element is then handed to `WMMA.store_d` together with `frag`, `size(workspace, 1)`,
# `WMMA.ColMajor` and a `WMMA.Config{M, N, K, Float32}`. Whatever `WMMA.store_d` returns
# is the result of this function.
function store_d(
    ::Type{WMMAOp{M, N, K}},
    ::Type{Layout.AlignedColMajor{Float32}},
    workspace,
    frag,
    tile::Tile,
) where {M, N, K}
    dims = size(workspace)
    addr = pointer(workspace, linearise(tile.index, dims))
    config = WMMA.Config{M, N, K, Float32}
    return WMMA.store_d(addr, frag, size(workspace, 1), WMMA.ColMajor, config)
end
6755
0 commit comments