@@ -27,32 +27,28 @@ struct WMMAOp{M, N, K} end
2727
2828function load_a (:: Type{WMMAOp{M, N, K}} , :: Type{Layout.AlignedColMajor{Float16}} , workspace, tile:: Tile ) where {M, N, K}
2929 conf = WMMA. Config{M, N, K, Float32}
30- ind = Tuple (tile. index) .+ 1
31- @inbounds linear_index = LinearIndices (size (workspace))[ind... ]
30+ linear_index = linearise (tile. index, size (workspace))
3231 ptr = pointer (workspace, linear_index)
3332 return WMMA. load_a (ptr, size (workspace, 1 ), WMMA. ColMajor, conf)
3433end
3534
3635function load_b (:: Type{WMMAOp{M, N, K}} , :: Type{Layout.AlignedColMajor{Float16}} , workspace, tile:: Tile ) where {M, N, K}
3736 conf = WMMA. Config{M, N, K, Float32}
38- ind = Tuple (tile. index) .+ 1
39- @inbounds linear_index = LinearIndices (size (workspace))[ind... ]
37+ linear_index = linearise (tile. index, size (workspace))
4038 ptr = pointer (workspace, linear_index)
4139 return WMMA. load_b (ptr, size (workspace, 1 ), WMMA. ColMajor, conf)
4240end
4341
4442function load_c (:: Type{WMMAOp{M, N, K}} , :: Type{Layout.AlignedColMajor{Float32}} , workspace, tile:: Tile ) where {M, N, K}
4543 conf = WMMA. Config{M, N, K, Float32}
46- ind = Tuple (tile. index) .+ 1
47- @inbounds linear_index = LinearIndices (size (workspace))[ind... ]
44+ linear_index = linearise (tile. index, size (workspace))
4845 ptr = pointer (workspace, linear_index)
4946 return WMMA. load_c (ptr, size (workspace, 1 ), WMMA. ColMajor, conf)
5047end
5148
5249function store_d (:: Type{WMMAOp{M, N, K}} , :: Type{Layout.AlignedColMajor{Float32}} , workspace, frag, tile:: Tile ) where {M, N, K}
5350 conf = WMMA. Config{M, N, K, Float32}
54- ind = Tuple (tile. index) .+ 1
55- @inbounds linear_index = LinearIndices (size (workspace))[ind... ]
51+ linear_index = linearise (tile. index, size (workspace))
5652 ptr = pointer (workspace, linear_index)
5753 WMMA. store_d (ptr, frag, size (workspace, 1 ), WMMA. ColMajor, conf)
5854end
0 commit comments