Skip to content

Commit 6428ca7

Browse files
mbelicki authored and igcbot committed
Support for large JointMatrix slices.
This patch adds support for 32x64x16 MxNxK combination to JointMatrix extension.
1 parent bde6fc5 commit 6428ca7

File tree

3 files changed

+381
-32
lines changed

3 files changed

+381
-32
lines changed

IGC/BiFModule/Languages/OpenCL/PreRelease/IBiF_matrix.cl

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ extern __constant int __JointMatrixLoadStoreOpt;
118118
#define ARR_TO_VEC1(type, arr) \
119119
arr[0]
120120

121+
#define OUT_VEC16(type) type##16
121122
#define OUT_VEC8(type) type##8
122123
#define OUT_VEC7(type) type##8
123124
#define OUT_VEC6(type) type##8
@@ -610,3 +611,190 @@ DEFINE_GET_COORD(Accumulator, , 32, 8, 8, 8x8, 8, 1)
610611
//bfloat16
611612
DEFINE_GET_COORD(PackedA, , 16, 8, 16, 8x16, 8, 1)
612613
DEFINE_GET_COORD(PackedB, , 16, 16, 8, 16x8, 8, 2)
614+
615+
/* experimental large slice support: */
616+
617+
/*
 * Multiply-accumulate for one 16x16x16 tile: bf16 A and B operands, fp32
 * accumulator.
 *
 *   a_ptr     - 16 shorts of A data (read as short16).
 *   b_ptr     - 8 ints of B data (read as int8).
 *   raw_c_ptr - 16 accumulator values stored as raw 32-bit words (int16);
 *               their bits are reinterpreted as fp32.
 *   result    - receives 16 raw 32-bit words holding the fp32 results.
 *
 * The work is split into two fdpas calls that each consume 8 components of A
 * and 8 components of C against the same B operand. The low halves produce
 * result components 0..7 and the high halves produce components 8..15.
 * NOTE(review): presumably the halves correspond to rows 0..7 and 8..15 of
 * the tile - confirm against the slice layout used by the loads.
 */
INLINE void __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(__private char *a_ptr, __private char *b_ptr, __private char *raw_c_ptr, __private char *result) {
    short16 a = *(short16 *)a_ptr;
    int8 b = *(int8 *)b_ptr;
    int16 raw_c = *(int16 *)raw_c_ptr;

    /* The accumulator arrives as raw bits; view the same bits as fp32. */
    float16 c = *(float16 *)&raw_c;

    /* .lo is components 0..7 and .hi is components 8..15 of a 16-wide vector. */
    short8 a_lo = a.lo;
    short8 a_hi = a.hi;
    float8 c_lo = c.lo;
    float8 c_hi = c.hi;

    /* Each half of A must accumulate into the matching half of C. Feeding
       c_lo to both calls would drop the upper half of the accumulator. */
    float8 fres_lo = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_8(c_lo, a_lo, b);
    float8 fres_hi = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_8(c_hi, a_hi, b);

    /* Hand the fp32 results back as raw 32-bit words. */
    int8 res_lo = *(int8 *)&fres_lo;
    int8 res_hi = *(int8 *)&fres_hi;

    __private int16 *dst = (__private int16 *)result;
    *dst = (int16)(res_lo, res_hi);
}
640+
641+
/*
 * Multiply-accumulate for one 16x16x16 tile: fp16 A and B operands, fp32
 * accumulator.
 *
 *   a_ptr     - 16 shorts of A data (read as short16).
 *   b_ptr     - 8 ints of B data (read as int8).
 *   raw_c_ptr - 16 accumulator values stored as raw 32-bit words (int16);
 *               their bits are reinterpreted as fp32.
 *   result    - receives 16 raw 32-bit words holding the fp32 results.
 *
 * The work is split into two fdpas calls that each consume 8 components of A
 * and 8 components of C against the same B operand. The low halves produce
 * result components 0..7 and the high halves produce components 8..15.
 * NOTE(review): presumably the halves correspond to rows 0..7 and 8..15 of
 * the tile - confirm against the slice layout used by the loads.
 */
INLINE void __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_fp16_fp16_fp32(__private char *a_ptr, __private char *b_ptr, __private char *raw_c_ptr, __private char *result) {
    short16 a = *(short16 *)a_ptr;
    int8 b = *(int8 *)b_ptr;
    int16 raw_c = *(int16 *)raw_c_ptr;

    /* The accumulator arrives as raw bits; view the same bits as fp32. */
    float16 c = *(float16 *)&raw_c;

    /* .lo is components 0..7 and .hi is components 8..15 of a 16-wide vector. */
    short8 a_lo = a.lo;
    short8 a_hi = a.hi;
    float8 c_lo = c.lo;
    float8 c_hi = c.hi;

    /* Each half of A must accumulate into the matching half of C. Feeding
       c_lo to both calls would drop the upper half of the accumulator. */
    float8 fres_lo = __builtin_IB_sub_group16_fdpas_f_f_hf_hf_8_8(c_lo, a_lo, b);
    float8 fres_hi = __builtin_IB_sub_group16_fdpas_f_f_hf_hf_8_8(c_hi, a_hi, b);

    /* Hand the fp32 results back as raw 32-bit words. */
    int8 res_lo = *(int8 *)&fres_lo;
    int8 res_hi = *(int8 *)&fres_hi;

    __private int16 *dst = (__private int16 *)result;
    *dst = (int16)(res_lo, res_hi);
}
664+
665+
/*
 * 32x64x16 bf16 multiply-accumulate, built from eight calls to the 16x16x16
 * bf16 helper.
 *
 * a_ptr is treated as two consecutive A slices and b_ptr as four consecutive
 * B slices, each 16 shorts long. c_ptr and d_ptr are treated as eight
 * consecutive accumulator slices of 16 ints each. Accumulator slice
 * (i * 4 + j) pairs A slice i with B slice j.
 */
INLINE void __builtin_spriv_OpJointMatrixMadINTEL_32x64x16_bf16_bf16_fp32(__private char *a_ptr, __private char *b_ptr, __private char *c_ptr, __private char *d_ptr) {
    /* Bytes occupied by one 16x16 sub-tile slice of each operand. */
    const size_t a_step   = 16 * (sizeof (short));
    const size_t b_step   = 16 * (sizeof (short));
    const size_t acc_step = 16 * (sizeof (int));

    __private char *a_first  = a_ptr;
    __private char *a_second = a_ptr + a_step;

    /* First A slice against every B slice: accumulator slices 0..3. */
    __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(a_first, b_ptr,              c_ptr + 0 * acc_step, d_ptr + 0 * acc_step);
    __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(a_first, b_ptr + 1 * b_step, c_ptr + 1 * acc_step, d_ptr + 1 * acc_step);
    __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(a_first, b_ptr + 2 * b_step, c_ptr + 2 * acc_step, d_ptr + 2 * acc_step);
    __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(a_first, b_ptr + 3 * b_step, c_ptr + 3 * acc_step, d_ptr + 3 * acc_step);

    /* Second A slice against every B slice: accumulator slices 4..7. */
    __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(a_second, b_ptr,              c_ptr + 4 * acc_step, d_ptr + 4 * acc_step);
    __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(a_second, b_ptr + 1 * b_step, c_ptr + 5 * acc_step, d_ptr + 5 * acc_step);
    __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(a_second, b_ptr + 2 * b_step, c_ptr + 6 * acc_step, d_ptr + 6 * acc_step);
    __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32(a_second, b_ptr + 3 * b_step, c_ptr + 7 * acc_step, d_ptr + 7 * acc_step);
}
702+
703+
/*
 * Load of one 16x16 accumulator tile of 32-bit elements from mem into dst.
 * The whole body is produced by IMPLEMENT_BLOCK2D_LOAD_SG16_ROW_MAJOR, which
 * is defined outside this view.
 * NOTE(review): the macro presumably reads dst, mem and stride by name, and
 * stride presumably counts elements per row - confirm against the macro.
 */
INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32(__private char *dst, char *mem, long stride) {
704+
IMPLEMENT_BLOCK2D_LOAD_SG16_ROW_MAJOR(int, 32, int, 32, 16, 16, 16)
705+
}
706+
707+
/*
 * Load of one 16x16 A tile of 16-bit elements from mem into dst.
 * The whole body is produced by IMPLEMENT_BLOCK2D_LOAD_SG16_ROW_MAJOR, which
 * is defined outside this view.
 * NOTE(review): the macro presumably reads dst, mem and stride by name, and
 * stride presumably counts elements per row - confirm against the macro.
 */
INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_16x16_i16_generic_v8i8_pi32_i32(__private char *dst, char *mem, long stride) {
708+
IMPLEMENT_BLOCK2D_LOAD_SG16_ROW_MAJOR(short, 16, short, 16, 16, 16, 16)
709+
}
710+
711+
/*
 * Load of a 32x64 accumulator tile, performed as eight 16x16 loads.
 *
 * Destination slice k starts k * 16 ints into dst. In memory, slices 0..3
 * start at byte offsets 0, 1, 2, 3 times (16 * sizeof(int)) from mem.
 * Slices 4..7 use the same offsets plus an extra 16 * sizeof(int) * stride
 * bytes.
 * NOTE(review): that extra offset is 16 rows down if stride counts int
 * elements per row - confirm against the caller.
 */
INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_32x64_i32_generic_v8i8_pi32_i32(__private char *dst, char *mem, long stride) {
    /* Bytes spanned by 16 ints: one destination slice, and one tile width in memory. */
    const size_t tile_bytes = 16 * (sizeof (int));

    /* Slices 0..3: tiles that start at mem itself. */
    __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32(dst + 0 * tile_bytes, mem + 0 * tile_bytes, stride);
    __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32(dst + 1 * tile_bytes, mem + 1 * tile_bytes, stride);
    __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32(dst + 2 * tile_bytes, mem + 2 * tile_bytes, stride);
    __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32(dst + 3 * tile_bytes, mem + 3 * tile_bytes, stride);

    /* Slices 4..7: same column offsets, shifted by tile_bytes * stride. */
    __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32(dst + 4 * tile_bytes, mem + 0 * tile_bytes + tile_bytes * stride, stride);
    __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32(dst + 5 * tile_bytes, mem + 1 * tile_bytes + tile_bytes * stride, stride);
    __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32(dst + 6 * tile_bytes, mem + 2 * tile_bytes + tile_bytes * stride, stride);
    __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32(dst + 7 * tile_bytes, mem + 3 * tile_bytes + tile_bytes * stride, stride);
}
739+
740+
/*
 * Load of a 32x16 A tile, performed as two 16x16 loads.
 *
 * The first load fills dst from mem. The second fills the slice 16 shorts
 * further into dst, reading from 16 * sizeof(short) * stride bytes past mem.
 * NOTE(review): that offset is 16 rows down if stride counts short elements
 * per row - confirm against the caller.
 */
INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_32x16_i16_generic_v8i8_pi32_i32(__private char *dst, char *mem, long stride) {
    /* Bytes spanned by 16 shorts. */
    const size_t half_bytes = 16 * (sizeof (short));

    __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_16x16_i16_generic_v8i8_pi32_i32(dst, mem, stride);
    __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_16x16_i16_generic_v8i8_pi32_i32(dst + half_bytes, mem + half_bytes * stride, stride);
}
750+
751+
/*
 * Load of a 16x64 packed-B tile, performed as four 16x16 packed-B loads.
 *
 * Destination slice k starts k * 16 shorts into dst. Source tile k starts
 * k * 16 ints (not shorts) past mem. The same stride is forwarded to every
 * load.
 * NOTE(review): the int-sized source step presumably reflects two 16-bit
 * elements packed per 32-bit word in the B layout - confirm against the
 * 16x16 packed-B loader.
 */
INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x64_i16_generic_v8i8_pi32_i32(__private char *dst, char *mem, long stride) {
    const size_t dst_step = 16 * (sizeof (short)); /* bytes per destination slice */
    const size_t mem_step = 16 * (sizeof (int));   /* bytes between source tiles */

    __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_generic_v8i8_pi32_i32(dst,                mem + 0 * mem_step, stride);
    __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_generic_v8i8_pi32_i32(dst + 1 * dst_step, mem + 1 * mem_step, stride);
    __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_generic_v8i8_pi32_i32(dst + 2 * dst_step, mem + 2 * mem_step, stride);
    __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_generic_v8i8_pi32_i32(dst + 3 * dst_step, mem + 3 * mem_step, stride);
}
767+
768+
/*
 * Store of one 16x16 accumulator tile of 32-bit elements from src to mem.
 * The whole body is produced by IMPLEMENT_BLOCK2D_STORE_SG16, which is
 * defined outside this view.
 * NOTE(review): the macro presumably reads mem, src and stride by name, and
 * the trailing `slice` argument is not declared in this function - confirm
 * how the macro uses it.
 */
INLINE void __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8(char *mem, __private char *src, long stride) {
769+
IMPLEMENT_BLOCK2D_STORE_SG16(int, int, 32, 16, 16, slice)
770+
}
771+
772+
/*
 * Store of a 32x64 accumulator tile, performed as eight 16x16 stores.
 *
 * Source slice k starts k * 16 ints into src. In memory, slices 0..3 go to
 * byte offsets 0, 1, 2, 3 times (16 * sizeof(int)) from mem. Slices 4..7 use
 * the same offsets plus an extra 16 * sizeof(int) * stride bytes.
 * NOTE(review): that extra offset is 16 rows down if stride counts int
 * elements per row - confirm against the caller.
 */
INLINE void __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_32x64_i32_generic_pi64_v8i8(char *mem, __private char *src, long stride) {
    /* Bytes spanned by 16 ints: one source slice, and one tile width in memory. */
    const size_t tile_bytes = 16 * (sizeof (int));

    /* Slices 0..3: tiles that start at mem itself. */
    __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8(mem + 0 * tile_bytes, src + 0 * tile_bytes, stride);
    __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8(mem + 1 * tile_bytes, src + 1 * tile_bytes, stride);
    __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8(mem + 2 * tile_bytes, src + 2 * tile_bytes, stride);
    __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8(mem + 3 * tile_bytes, src + 3 * tile_bytes, stride);

    /* Slices 4..7: same column offsets, shifted by tile_bytes * stride. */
    __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8(mem + 0 * tile_bytes + tile_bytes * stride, src + 4 * tile_bytes, stride);
    __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8(mem + 1 * tile_bytes + tile_bytes * stride, src + 5 * tile_bytes, stride);
    __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8(mem + 2 * tile_bytes + tile_bytes * stride, src + 6 * tile_bytes, stride);
    __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8(mem + 3 * tile_bytes + tile_bytes * stride, src + 7 * tile_bytes, stride);
}
800+

0 commit comments

Comments
 (0)