diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc index 135888c23d8b..f912e482761c 100644 --- a/src/target/spirv/ir_builder.cc +++ b/src/target/spirv/ir_builder.cc @@ -178,23 +178,24 @@ SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems, } else { ib_.Begin(spv::OpTypeRuntimeArray).AddSeq(arr_type, value_type).Commit(&global_); } - int nbits = value_type.type.bits() * value_type.type.lanes(); - TVM_FFI_ICHECK_EQ(nbits % 8, 0); - uint32_t nbytes = static_cast(nbits) / 8; - // decorate the array type. - this->Decorate(spv::OpDecorate, arr_type, spv::DecorationArrayStride, nbytes); + if (interface_block) { + int nbits = value_type.type.bits() * value_type.type.lanes(); + TVM_FFI_ICHECK_EQ(nbits % 8, 0); + uint32_t nbytes = static_cast(nbits) / 8; + // Explicit layout is required for descriptor-backed interface blocks. + this->Decorate(spv::OpDecorate, arr_type, spv::DecorationArrayStride, nbytes); + } // declare struct of array SType struct_type; struct_type.id = id_counter_++; struct_type.type = DataType::Handle(); struct_type.element_type_id = value_type.id; ib_.Begin(spv::OpTypeStruct).AddSeq(struct_type, arr_type).Commit(&global_); - // decorate the array type. - ib_.Begin(spv::OpMemberDecorate) - .AddSeq(struct_type, 0, spv::DecorationOffset, 0) - .Commit(&decorate_); if (interface_block) { + ib_.Begin(spv::OpMemberDecorate) + .AddSeq(struct_type, 0, spv::DecorationOffset, 0) + .Commit(&decorate_); // Runtime array are always decorated as Block or BufferBlock // (shader storage buffer) if (spirv_support_.supports_storage_buffer_storage_class) { diff --git a/tests/python/codegen/test_target_codegen_vulkan.py b/tests/python/codegen/test_target_codegen_vulkan.py index 38830ae96f30..9af08c1a04bb 100644 --- a/tests/python/codegen/test_target_codegen_vulkan.py +++ b/tests/python/codegen/test_target_codegen_vulkan.py @@ -515,6 +515,24 @@ def kernel(): vulkan_codegen(Module, target) +@tvm.testing.requires_vulkan(support_required="compile-only") +def test_codegen_static_shared_memory(): + """The codegen should accept static shared/workgroup allocations.""" + + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): + A_shared = T.alloc_buffer((128,), dtype="float32", scope="shared") + + for bx in T.thread_binding(1, thread="blockIdx.x"): + for tx in T.thread_binding(128, thread="threadIdx.x"): + A_shared[tx] = A[tx] + B[tx] = A_shared[tx] + + tvm.compile(Module, target="vulkan") + + @tvm.testing.requires_gpu @tvm.testing.requires_vulkan def test_unary():