diff --git a/sycl/doc/design/spirv-extensions/SPV_INTEL_subgroup_scaled_matrix_multiply_accumulate.asciidoc b/sycl/doc/design/spirv-extensions/SPV_INTEL_subgroup_scaled_matrix_multiply_accumulate.asciidoc new file mode 100644 index 0000000000000..1fe1d16944277 --- /dev/null +++ b/sycl/doc/design/spirv-extensions/SPV_INTEL_subgroup_scaled_matrix_multiply_accumulate.asciidoc @@ -0,0 +1,532 @@ +:extension_name: SPV_INTEL_{zwsp}subgroup_{zwsp}scaled_{zwsp}matrix_{zwsp}multiply_{zwsp}accumulate +:capability_name: Subgroup{zwsp}Scaled{zwsp}Matrix{zwsp}Multiply{zwsp}Accumulate{zwsp}INTEL +:capability_token: 6263 +:op_name_scaled_mma: OpSubgroupScaledMatrixMultiplyAccumulateINTEL +:op_token_scaled_mma: 6264 +:ocp_microscaling_formats_url: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf + +:MatrixResultBFloat16INTEL: Matrix{zwsp}Result{zwsp}BFloat16{zwsp}INTEL + +:MatrixAPackedFloat16INTEL: MatrixA{zwsp}Packed{zwsp}Float16{zwsp}INTEL +:MatrixAPackedBFloat16INTEL: MatrixA{zwsp}Packed{zwsp}BFloat16{zwsp}INTEL +:MatrixAPackedFloat8E4M3INTEL: MatrixA{zwsp}Packed{zwsp}Float8{zwsp}E4M3{zwsp}INTEL +:MatrixAPackedFloat8E5M2INTEL: MatrixA{zwsp}Packed{zwsp}Float8{zwsp}E5M2{zwsp}INTEL +:MatrixAPackedFloat4E2M1INTEL: MatrixA{zwsp}Packed{zwsp}Float4{zwsp}E2M1{zwsp}INTEL + +:MatrixBPackedFloat16INTEL: MatrixB{zwsp}Packed{zwsp}Float16{zwsp}INTEL +:MatrixBPackedBFloat16INTEL: MatrixB{zwsp}Packed{zwsp}BFloat16{zwsp}INTEL +:MatrixBPackedFloat8E4M3INTEL: MatrixB{zwsp}Packed{zwsp}Float8{zwsp}E4M3{zwsp}INTEL +:MatrixBPackedFloat8E5M2INTEL: MatrixB{zwsp}Packed{zwsp}Float8{zwsp}E5M2{zwsp}INTEL +:MatrixBPackedFloat4E2M1INTEL: MatrixB{zwsp}Packed{zwsp}Float4{zwsp}E2M1{zwsp}INTEL + +:MatrixCBFloat16INTEL: MatrixC{zwsp}BFloat16{zwsp}INTEL + +:ScaleAPackedFloat8E8M0INTEL: ScaleA{zwsp}Packed{zwsp}Float8{zwsp}E8M0{zwsp}INTEL +:ScaleBPackedFloat8E8M0INTEL: ScaleB{zwsp}Packed{zwsp}Float8{zwsp}E8M0{zwsp}INTEL + +{extension_name} +================ + +== Name Strings + 
+{extension_name} + +== Contact + +To report problems with this extension, please open a new issue at: + +https://github.com/intel/llvm + +// TODO: Switch to the SPIR-V registry for publication. + +== Contributors + +// spell-checker: disable +* Ben Ashbaugh, Intel +* Dmitry Sidorov, Intel +// spell-checker: enable + +== Notice + +Copyright (c) 2025 Intel Corporation. All rights reserved. + +== Status + +* Working Draft + +This is a preview extension specification, intended to provide early access to a feature for review and community feedback. When the feature matures, this specification may be released as a formal extension. + +Because the interfaces defined by this specification are not final and are subject to change, they are not intended to be used by shipping software products. If you are interested in using this feature in your software product, please let us know! + +== Version + +[width="40%",cols="25,25"] +|======================================== +| Last Modified Date | 2025-11-13 +| Revision | 1 +|======================================== + +== Dependencies + +This extension is written against the SPIR-V Specification, +Version 1.6, Revision 6. + +This extension requires SPIR-V 1.0. + +This extension depends on the *SPV_INTEL_subgroup_matrix_multiply_accumulate* extension and other related extensions. +In particular, it reuses the Matrix Multiply Accumulate Operands defined by these extensions to describe how the matrix operands should be interpreted. + +== Overview + +This extension adds an instruction to compute the matrix product of an M x K matrix (referred to as _Matrix A_ in this extension) with a K x N matrix (_Matrix B_) and then add an M x N matrix (_Matrix C_), with scaling factors applied to the elements of _Matrix A_ and _Matrix B_. 
+ +This instruction is similar to the instruction added by the *SPV_INTEL_subgroup_matrix_multiply_accumulate* extension, with the scaling factors allowing low-precision matrix elements to represent a wider range of values. + +For additional background information about the formats described by this extension and the scaled matrix product, please refer to the {ocp_microscaling_formats_url}[OCP Microscaling Formats (MX) Specification, Version 1]. + +== Extension Name + +To use this extension within a SPIR-V module, the appropriate *OpExtension* must +be present in the module: + +[subs="attributes"] +---- +OpExtension "{extension_name}" +---- + +== Modifications to the SPIR-V Specification, Version 1.6 + +=== Capabilities + +Modify Section 3.31, Capability, adding rows to the Capability table: +-- +[cols="^.^2,16,15",options="header",width = "100%"] +|==== +2+^.^| Capability | Implicitly Declares +| {capability_token} | *{capability_name}* | +|==== +-- + +=== Matrix Multiply Accumulate Operands + +Modify Section 3.2.53, Matrix Multiply Accumulate Operands, which may also be found in the *SPV_INTEL_subgroup_matrix_multiply_accumulate* extension specification, adding rows to the table: + +[cols="^.^4,16,15",options="header",width = "100%"] +|==== +2+^.^| Matrix Multiply Accumulate Operands | Enabling Capabilities + +// Only valid for integer operand types: +| 0x100000 | *ScaleAFloat8E8M0INTEL* + +The scale factor for matrix A is interpreted as fp8 E8M0 data. | +| 0x200000 | *ScaleBFloat8E8M0INTEL* + +The scale factor for matrix B is interpreted as fp8 E8M0 data. | + +|==== + +=== Instructions + +Modify Section 3.42.21, Group Instructions, adding to the end of the list of instructions: + +[cols="1,1,9*3",width="100%"] +|===== +8+a|[[{op_name_scaled_mma}]]*{op_name_scaled_mma}* + +Computes the scaled matrix product of two matrix operands and adds a third matrix operand. +All invocations in the subgroup cooperate to perform this operation. 
+ +_Result Type_ defines the result of the scaled matrix multiply accumulate operation. +It must be a scalar or vector of floating-point or integer type. +The number of components in _Result Type_ defines the _M_ dimension of the scaled matrix multiply accumulate operation. +If _Result Type_ is a scalar type, the _M_ dimension is one. + +_K Dim_ defines the _K_ dimension of the scaled matrix multiply accumulate operation. +It must come from a constant instruction with scalar 32-bit integer type. + +The _N_ dimension of the scaled matrix multiply accumulate operation is implicitly the number of invocations in the subgroup. + +_Matrix A_ is the first matrix operand and has _M_ rows and _K_ columns. +The type of _Matrix A_ must be a scalar or vector of floating-point or integer type. +Multiple invocations in the subgroup may contribute part of the _Matrix A_ operand, depending on the matrix operand size and the subgroup size. + +_Matrix B_ is the second matrix operand and has _K_ rows and _N_ columns. +It must be a scalar or vector of floating-point or integer type. +Each of the invocations in the subgroup contributes part of the _Matrix B_ operand. + +_Matrix C_ is the third matrix operand and has _M_ rows and _N_ columns. +It must be a scalar or vector of floating-point or integer type. +Each of the invocations in the subgroup contributes part of the _Matrix C_ operand. + +_Scale A_ is a scale factor that is applied to elements of _Matrix A_. +It must be a scalar or vector of floating-point or integer type. +Each invocation may contribute part of the _Scale A_ operand, depending on the matrix operand size and the subgroup size. + +_Scale B_ is a scale factor that is applied to elements of _Matrix B_. +It must be a scalar or vector of floating-point or integer type. +Each of the invocations in the subgroup contributes part of the _Scale B_ operand. 
+ +The multiplication step of the scaled matrix multiply accumulate operation computes the scaled matrix product of _Matrix A_ and _Matrix B_. +The product is a matrix with _M_ rows and _N_ columns. +The order of operations to compute the elements of the matrix product is implementation-dependent. + +For integer matrices, the operations used for the multiplication of _Matrix A_ and _Matrix B_ and the addition of _Matrix C_ are performed at the precision of the _Result Type_. +The resulting value will equal the low-order N bits of the correct result R, where N is the result width and R is computed with enough precision to avoid overflow and underflow. + +For floating-point matrices, the precision and the order of operations are implementation-defined. + +The accumulation step of the scaled matrix multiply accumulate operation computes the element-wise addition of the scaled matrix product of _Matrix A_ and _Matrix B_ with _Matrix C_. +The final result is a matrix with _M_ rows and _N_ columns, which is assigned to _Result_. + +_Matrix Multiply Accumulate Operands_ is an optional literal that specifies additional information about the matrix operands, such as ways to reinterpret the bits passed as the matrix operands or scale factors. +If _Matrix Multiply Accumulate Operands_ is not present, it is the same as specifying the _Matrix Multiply Accumulate Operand_ *None*. + +Behavior is undefined unless all invocations within the subgroup execute the same dynamic instance of this instruction. 
+ +3+a|Capability: + +*{capability_name}* +| 9 + variable | {op_token_scaled_mma} +| _<id>_ + +_Result Type_ +| _<id>_ + +_Result_ +| _<id>_ + +_K Dim_ +| _<id>_ + +_Matrix A_ +| _<id>_ + +_Matrix B_ +| _<id>_ + +_Matrix C_ +| _<id>_ + +_Scale A_ +| _<id>_ + +_Scale B_ +| Optional + +_Matrix Multiply Accumulate Operands_ +|===== + +== Mapping Scale Factors to Invocations + +The _Scale A_ factors for the _M_ rows of _Matrix A_ are provided by the _M_ lower-numbered invocations in the subgroup, with each invocation providing the scale factors for one row of _Matrix A_. +If a row of _Matrix A_ requires two or more scale factors, the scale factors will be provided as a vector, with the scale factor for the first set of matrix elements in the row provided by the first component, the scale factor for the next set of matrix elements in the row provided by the next component, and so on. + +The _Scale B_ factors for the _N_ columns of _Matrix B_ are provided by the _N_ lower-numbered invocations in the subgroup, with each invocation providing the scale factors for one column of _Matrix B_. +If a column of _Matrix B_ requires two or more scale factors, the scale factors will be provided as a vector, with the scale factor for the first set of matrix elements in the column provided by the first component, the scale factor for the next set of matrix elements in the column provided by the next component, and so on. + +== Supported Matrix Dimensions and Types + +[NOTE] +==== +This section will be moved to a client API specification before final publication, but is included in this SPIR-V extension for now for ease of review. +==== + +For devices where the minimum subgroup size is 16, the following matrix dimensions and types are supported when the subgroup size is 16. 
+Behavior is undefined if these combinations are used on other devices or from kernels with a different subgroup size: + +[cols="^1a,^1a,^1a,^1a,^2a,^2a,^2a,^2a,^2a,^2a",width="100%"] +[options="header"] +|===== +| Sub-group Size | M Dim | N Dim | K Dim | Result Type | Matrix A Type | Matrix B Type | Matrix C Type | Scale A Type | Scale B Type + +// bdpas reference: https://gfxspecs.intel.com/Predator/Home/Index/74737 + +// f32 = fp16 x fp16 + f32 +10+<| *fp16 matrix sources, fp32 accumulator, fp8 scale factors (e8m0)*: +| 16 | 8 | 16 | 16 | `8 x float32_t` +| `8 x int16_t` with *{MatrixAPackedFloat16INTEL}* +| `8 x int32_t` with *{MatrixBPackedFloat16INTEL}* +| `8 x float32_t` +| `uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +// f32 = bf16 x bf16 + f32 +10+<| *bf16 matrix sources, fp32 accumulator, fp8 scale factors (e8m0)*: +| 16 | 8 | 16 | 16 | `8 x float32_t` +| `8 x int16_t` with *{MatrixAPackedBFloat16INTEL}* +| `8 x int32_t` with *{MatrixBPackedBFloat16INTEL}* +| `8 x float32_t` +| `uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +// f16 = fp16 x fp16 + f16 +10+<| *fp16 matrix sources, fp16 accumulator, fp8 scale factors (e8m0)*: +| 16 | 8 | 16 | 16 | `8 x float16_t` +| `8 x int16_t` with *{MatrixAPackedFloat16INTEL}* +| `8 x int32_t` with *{MatrixBPackedFloat16INTEL}* +| `8 x float16_t` +| `uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +// bf16 = bf16 x bf16 + bf16 +10+<| *bf16 matrix sources, bf16 accumulator, fp8 scale factors (e8m0)*: +| 16 | 8 | 16 | 16 | `8 x int16_t` with *{MatrixResultBFloat16INTEL}* +| `8 x int16_t` with *{MatrixAPackedBFloat16INTEL}* +| `8 x int32_t` with *{MatrixBPackedBFloat16INTEL}* +| `8 x int16_t` with *{MatrixCBFloat16INTEL}* +| `uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +// f32 = f8e4m3 x f8e4m3 + f32 +// f32 = f8e4m3 
x f8e5m2 + f32 +// f32 = f8e5m2 x f8e4m3 + f32 +// f32 = f8e5m2 x f8e5m2 + f32 +10+<| *fp8 matrix sources (e4m3 and e5m2), fp32 accumulator, fp8 scale factors (e8m0)*: +| 16 | 8 | 16 | 32 | `8 x float32_t` +| `8 x int16_t` with *{MatrixAPackedFloat8E4M3INTEL}* +| `8 x int32_t` with *{MatrixBPackedFloat8E4M3INTEL}* +| `8 x float32_t` +| `uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +| 16 | 8 | 16 | 32 | `8 x float32_t` +| `8 x int16_t` with *{MatrixAPackedFloat8E4M3INTEL}* +| `8 x int32_t` with *{MatrixBPackedFloat8E5M2INTEL}* +| `8 x float32_t` +| `uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +| 16 | 8 | 16 | 32 | `8 x float32_t` +| `8 x int16_t` with *{MatrixAPackedFloat8E5M2INTEL}* +| `8 x int32_t` with *{MatrixBPackedFloat8E4M3INTEL}* +| `8 x float32_t` +| `uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +| 16 | 8 | 16 | 32 | `8 x float32_t` +| `8 x int16_t` with *{MatrixAPackedFloat8E5M2INTEL}* +| `8 x int32_t` with *{MatrixBPackedFloat8E5M2INTEL}* +| `8 x float32_t` +| `uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +// bf16 = f8e4m3 x f8e4m3 + bf16 +// bf16 = f8e4m3 x f8e5m2 + bf16 +// bf16 = f8e5m2 x f8e4m3 + bf16 +// bf16 = f8e5m2 x f8e5m2 + bf16 +10+<| *fp8 matrix sources (e4m3 and e5m2), bf16 accumulator, fp8 scale factors (e8m0)*: +| 16 | 8 | 16 | 32 | `8 x int16_t` with *{MatrixResultBFloat16INTEL}* +| `8 x int16_t` with *{MatrixAPackedFloat8E4M3INTEL}* +| `8 x int32_t` with *{MatrixBPackedFloat8E4M3INTEL}* +| `8 x int16_t` with *{MatrixCBFloat16INTEL}* +| `uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +| 16 | 8 | 16 | 32 | `8 x int16_t` with *{MatrixResultBFloat16INTEL}* +| `8 x int16_t` with *{MatrixAPackedFloat8E4M3INTEL}* +| `8 x int32_t` with *{MatrixBPackedFloat8E5M2INTEL}* +| `8 x int16_t` 
with *{MatrixCBFloat16INTEL}* +| `uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +| 16 | 8 | 16 | 32 | `8 x int16_t` with *{MatrixResultBFloat16INTEL}* +| `8 x int16_t` with *{MatrixAPackedFloat8E5M2INTEL}* +| `8 x int32_t` with *{MatrixBPackedFloat8E4M3INTEL}* +| `8 x int16_t` with *{MatrixCBFloat16INTEL}* +| `uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +| 16 | 8 | 16 | 32 | `8 x int16_t` with *{MatrixResultBFloat16INTEL}* +| `8 x int16_t` with *{MatrixAPackedFloat8E5M2INTEL}* +| `8 x int32_t` with *{MatrixBPackedFloat8E5M2INTEL}* +| `8 x int16_t` with *{MatrixCBFloat16INTEL}* +| `uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +// f32 = f4e2m1 x f4e2m1 + f32 +10+<| *fp4 matrix sources (e2m1), fp32 accumulator, fp8 scale factors (e8m0)*: +| 16 | 8 | 16 | 64 | `8 x float32_t` +| `8 x int16_t` with *{MatrixAPackedFloat4E2M1INTEL}* +| `8 x int32_t` with *{MatrixBPackedFloat4E2M1INTEL}* +| `8 x float32_t` +| `2 x uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `2 x uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +// bf16 = f4e2m1 x f4e2m1 + bf16 +10+<| *fp4 matrix sources (e2m1), bf16 accumulator, fp8 scale factors (e8m0)*: +| 16 | 8 | 16 | 64 | `8 x int16_t` with *{MatrixResultBFloat16INTEL}* +| `8 x int16_t` with *{MatrixAPackedFloat4E2M1INTEL}* +| `8 x int32_t` with *{MatrixBPackedFloat4E2M1INTEL}* +| `8 x int16_t` with *{MatrixCBFloat16INTEL}* +| `2 x uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `2 x uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +|===== + +For devices where the minimum subgroup size is 16, the following matrix dimensions and types are supported when the subgroup size is 32. + +When the subgroup size is 32, each invocation is responsible for either the even or odd rows of the matrix sources or result matrix, therefore the number of matrix rows M must be even. 
+The 16 invocations with the smallest subgroup local invocation IDs are responsible for the even matrix rows, starting from row zero, and the 16 invocations with the largest subgroup local invocation IDs are responsible for the odd matrix rows, starting from row one: + +Behavior is undefined if these combinations are used on other devices or from kernels with a different subgroup size: + +[cols="^1a,^1a,^1a,^1a,^2a,^2a,^2a,^2a,^2a,^2a",width="100%"] +[options="header"] +|===== +| Sub-group Size | M Dim | N Dim | K Dim | Result Type | Matrix A Type | Matrix B Type | Matrix C Type | Scale A Type | Scale B Type + +// bdpas reference: https://gfxspecs.intel.com/Predator/Home/Index/74737 + +// f32 = fp16 x fp16 + f32 +10+<| *fp16 matrix sources, fp32 accumulator, fp8 scale factors (e8m0)*: +| 32 | 8 | 16 | 16 | `4 x float32_t` +| `4 x int16_t` with *{MatrixAPackedFloat16INTEL}* +| `4 x int32_t` with *{MatrixBPackedFloat16INTEL}* +| `4 x float32_t` +| `uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +// f32 = bf16 x bf16 + f32 +10+<| *bf16 matrix sources, fp32 accumulator, fp8 scale factors (e8m0)*: +| 32 | 8 | 16 | 16 | `4 x float32_t` +| `4 x int16_t` with *{MatrixAPackedBFloat16INTEL}* +| `4 x int32_t` with *{MatrixBPackedBFloat16INTEL}* +| `4 x float32_t` +| `uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +// f16 = fp16 x fp16 + f16 +10+<| *fp16 matrix sources, fp16 accumulator, fp8 scale factors (e8m0)*: +| 32 | 8 | 16 | 16 | `4 x float16_t` +| `4 x int16_t` with *{MatrixAPackedFloat16INTEL}* +| `4 x int32_t` with *{MatrixBPackedFloat16INTEL}* +| `4 x float16_t` +| `uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +// bf16 = bf16 x bf16 + bf16 +10+<| *bf16 matrix sources, bf16 accumulator, fp8 scale factors (e8m0)*: +| 32 | 8 | 16 | 16 | `4 x int16_t` with *{MatrixResultBFloat16INTEL}* +| `4 x int16_t` with 
*{MatrixAPackedBFloat16INTEL}* +| `4 x int32_t` with *{MatrixBPackedBFloat16INTEL}* +| `4 x int16_t` with *{MatrixCBFloat16INTEL}* +| `uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +// f32 = f8e4m3 x f8e4m3 + f32 +// f32 = f8e4m3 x f8e5m2 + f32 +// f32 = f8e5m2 x f8e4m3 + f32 +// f32 = f8e5m2 x f8e5m2 + f32 +10+<| *fp8 matrix sources (e4m3 and e5m2), fp32 accumulator, fp8 scale factors (e8m0)*: +| 32 | 8 | 16 | 32 | `4 x float32_t` +| `4 x int16_t` with *{MatrixAPackedFloat8E4M3INTEL}* +| `4 x int32_t` with *{MatrixBPackedFloat8E4M3INTEL}* +| `4 x float32_t` +| `uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +| 32 | 8 | 16 | 32 | `4 x float32_t` +| `4 x int16_t` with *{MatrixAPackedFloat8E4M3INTEL}* +| `4 x int32_t` with *{MatrixBPackedFloat8E5M2INTEL}* +| `4 x float32_t` +| `uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +| 32 | 8 | 16 | 32 | `4 x float32_t` +| `4 x int16_t` with *{MatrixAPackedFloat8E5M2INTEL}* +| `4 x int32_t` with *{MatrixBPackedFloat8E4M3INTEL}* +| `4 x float32_t` +| `uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +| 32 | 8 | 16 | 32 | `4 x float32_t` +| `4 x int16_t` with *{MatrixAPackedFloat8E5M2INTEL}* +| `4 x int32_t` with *{MatrixBPackedFloat8E5M2INTEL}* +| `4 x float32_t` +| `uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +// bf16 = f8e4m3 x f8e4m3 + bf16 +// bf16 = f8e4m3 x f8e5m2 + bf16 +// bf16 = f8e5m2 x f8e4m3 + bf16 +// bf16 = f8e5m2 x f8e5m2 + bf16 +10+<| *fp8 matrix sources (e4m3 and e5m2), bf16 accumulator, fp8 scale factors (e8m0)*: +| 32 | 8 | 16 | 32 | `4 x int16_t` with *{MatrixResultBFloat16INTEL}* +| `4 x int16_t` with *{MatrixAPackedFloat8E4M3INTEL}* +| `4 x int32_t` with *{MatrixBPackedFloat8E4M3INTEL}* +| `4 x int16_t` with *{MatrixCBFloat16INTEL}* +| `uint8_t` with 
*{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +| 32 | 8 | 16 | 32 | `4 x int16_t` with *{MatrixResultBFloat16INTEL}* +| `4 x int16_t` with *{MatrixAPackedFloat8E4M3INTEL}* +| `4 x int32_t` with *{MatrixBPackedFloat8E5M2INTEL}* +| `4 x int16_t` with *{MatrixCBFloat16INTEL}* +| `uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +| 32 | 8 | 16 | 32 | `4 x int16_t` with *{MatrixResultBFloat16INTEL}* +| `4 x int16_t` with *{MatrixAPackedFloat8E5M2INTEL}* +| `4 x int32_t` with *{MatrixBPackedFloat8E4M3INTEL}* +| `4 x int16_t` with *{MatrixCBFloat16INTEL}* +| `uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +| 32 | 8 | 16 | 32 | `4 x int16_t` with *{MatrixResultBFloat16INTEL}* +| `4 x int16_t` with *{MatrixAPackedFloat8E5M2INTEL}* +| `4 x int32_t` with *{MatrixBPackedFloat8E5M2INTEL}* +| `4 x int16_t` with *{MatrixCBFloat16INTEL}* +| `uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +// f32 = f4e2m1 x f4e2m1 + f32 +10+<| *fp4 matrix sources (e2m1), fp32 accumulator, fp8 scale factors (e8m0)*: +| 32 | 8 | 16 | 64 | `4 x float32_t` +| `4 x int16_t` with *{MatrixAPackedFloat4E2M1INTEL}* +| `4 x int32_t` with *{MatrixBPackedFloat4E2M1INTEL}* +| `4 x float32_t` +| `2 x uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `2 x uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +// bf16 = f4e2m1 x f4e2m1 + bf16 +10+<| *fp4 matrix sources (e2m1), bf16 accumulator, fp8 scale factors (e8m0)*: +| 32 | 8 | 16 | 64 | `4 x int16_t` with *{MatrixResultBFloat16INTEL}* +| `4 x int16_t` with *{MatrixAPackedFloat4E2M1INTEL}* +| `4 x int32_t` with *{MatrixBPackedFloat4E2M1INTEL}* +| `4 x int16_t` with *{MatrixCBFloat16INTEL}* +| `2 x uint8_t` with *{ScaleAPackedFloat8E8M0INTEL}* +| `2 x uint8_t` with *{ScaleBPackedFloat8E8M0INTEL}* + +|===== + +== Issues + +. 
Should "subgroup" or "scaled" come first in the name of this extension, the new capability, and the new instruction? ++ +-- +`RESOLVED`: "Subgroup" will come before "scaled". +This is consistent with SPIR-V conventions for other extensions and instructions where the "group" part of the name comes first. +-- + +. Do we need an explicit operand to indicate the "Scaling Block Size"? ++ +-- +`RESOLVED`: No, we can currently infer the scaling block size from the element data type, so we will not include an explicit scaling block size operand. +-- + +. Do we need an explicit operand to indicate the M dimension, or can we continue to derive the M dimension from the result type? ++ +-- +`RESOLVED`: No, this is not necessary, because we currently only support cases with M equal to eight. +Therefore, we will retain consistency with the existing matrix multiply accumulate instruction, and we will not include an explicit M dimension operand. +-- + +. If there is more than one scaling factor per row or column, should they be represented as a vector, or as a packed scalar? ++ +-- +`RESOLVED`: We will represent multiple scaling factors as a vector. +-- + +. Do we need to support this extension with a sub-group size of 32, or is support for only a sub-group size of 16 sufficient? ++ +-- +`RESOLVED`: Although we currently plan only to support a sub-group size of 16, the extension will additionally work with a sub-group size of 32, if required. +-- + +== Revision History + +[cols="5,15,15,70"] +[grid="rows"] +[options="header"] +|======================================== +|Rev|Date|Author|Changes +|1|2025-11-13|Ben Ashbaugh|Initial revision for publication +|========================================