@@ -118,6 +118,7 @@ extern __constant int __JointMatrixLoadStoreOpt;
118118#define ARR_TO_VEC1 (type , arr ) \
119119 arr[0]
120120
121+ #define OUT_VEC16 (type ) type##16
121122#define OUT_VEC8 (type ) type##8
122123#define OUT_VEC7 (type ) type##8
123124#define OUT_VEC6 (type ) type##8
@@ -610,3 +611,190 @@ DEFINE_GET_COORD(Accumulator, , 32, 8, 8, 8x8, 8, 1)
610611//bfloat16
611612DEFINE_GET_COORD (PackedA , , 16 , 8 , 16 , 8 x16 , 8 , 1 )
612613DEFINE_GET_COORD (PackedB , , 16 , 16 , 8 , 16 x8 , 8 , 2 )
614+
615+ /* experimental large slice support: */
616+
617+ INLINE void __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32 (__private char * a_ptr , __private char * b_ptr , __private char * raw_c_ptr , __private char * result ) {
618+ short16 a = * (short16 * )a_ptr ;
619+ int8 b = * (int8 * )b_ptr ;
620+ int16 raw_c = * (int16 * )raw_c_ptr ;
621+
622+ short8 a0 = (short8 )(a .s0 , a .s1 , a .s2 , a .s3 , a .s4 , a .s5 , a .s6 , a .s7 );
623+ short8 a1 = (short8 )(a .s8 , a .s9 , a .sa , a .sb , a .sc , a .sd , a .se , a .sf );
624+
625+ float16 c = * (float16 * )& raw_c ;
626+
627+ float8 c0 = (float8 )(c .s0 , c .s1 , c .s2 , c .s3 , c .s4 , c .s5 , c .s6 , c .s7 );
628+ float8 c1 = (float8 )(c .s8 , c .s9 , c .sa , c .sb , c .sc , c .sd , c .se , c .sf );
629+
630+ float8 fres0 = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_8 (c0 , a0 , b );
631+ float8 fres1 = __builtin_IB_sub_group16_fdpas_f_f_bf_bf_8_8 (c0 , a1 , b );
632+
633+ int8 res0 = * (int8 * )& fres0 ;
634+ int8 res1 = * (int8 * )& fres1 ;
635+
636+ __private int16 * dst = (__private int16 * )result ;
637+ * dst = (int16 )(res0 .s0 , res0 .s1 , res0 .s2 , res0 .s3 , res0 .s4 , res0 .s5 , res0 .s6 , res0 .s7 ,
638+ res1 .s0 , res1 .s1 , res1 .s2 , res1 .s3 , res1 .s4 , res1 .s5 , res1 .s6 , res1 .s7 );
639+ }
640+
641+ INLINE void __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_fp16_fp16_fp32 (__private char * a_ptr , __private char * b_ptr , __private char * raw_c_ptr , __private char * result ) {
642+ short16 a = * (short16 * )a_ptr ;
643+ int8 b = * (int8 * )b_ptr ;
644+ int16 raw_c = * (int16 * )raw_c_ptr ;
645+
646+ short8 a0 = (short8 )(a .s0 , a .s1 , a .s2 , a .s3 , a .s4 , a .s5 , a .s6 , a .s7 );
647+ short8 a1 = (short8 )(a .s8 , a .s9 , a .sa , a .sb , a .sc , a .sd , a .se , a .sf );
648+
649+ float16 c = * (float16 * )& raw_c ;
650+
651+ float8 c0 = (float8 )(c .s0 , c .s1 , c .s2 , c .s3 , c .s4 , c .s5 , c .s6 , c .s7 );
652+ float8 c1 = (float8 )(c .s8 , c .s9 , c .sa , c .sb , c .sc , c .sd , c .se , c .sf );
653+
654+ float8 fres0 = __builtin_IB_sub_group16_fdpas_f_f_hf_hf_8_8 (c0 , a0 , b );
655+ float8 fres1 = __builtin_IB_sub_group16_fdpas_f_f_hf_hf_8_8 (c0 , a1 , b );
656+
657+ int8 res0 = * (int8 * )& fres0 ;
658+ int8 res1 = * (int8 * )& fres1 ;
659+
660+ __private int16 * dst = (__private int16 * )result ;
661+ * dst = (int16 )(res0 .s0 , res0 .s1 , res0 .s2 , res0 .s3 , res0 .s4 , res0 .s5 , res0 .s6 , res0 .s7 ,
662+ res1 .s0 , res1 .s1 , res1 .s2 , res1 .s3 , res1 .s4 , res1 .s5 , res1 .s6 , res1 .s7 );
663+ }
664+
665+ INLINE void __builtin_spriv_OpJointMatrixMadINTEL_32x64x16_bf16_bf16_fp32 (__private char * a_ptr , __private char * b_ptr , __private char * c_ptr , __private char * d_ptr ) {
666+ __private char * a0 = a_ptr ;
667+ __private char * a1 = a_ptr + 16 * (sizeof (short ));
668+
669+ __private char * b0 = b_ptr ;
670+ __private char * b1 = b_ptr + 1 * 16 * (sizeof (short ));
671+ __private char * b2 = b_ptr + 2 * 16 * (sizeof (short ));
672+ __private char * b3 = b_ptr + 3 * 16 * (sizeof (short ));
673+
674+ __private char * c0 = c_ptr + 0 * 16 * (sizeof (int ));
675+ __private char * c1 = c_ptr + 1 * 16 * (sizeof (int ));
676+ __private char * c2 = c_ptr + 2 * 16 * (sizeof (int ));
677+ __private char * c3 = c_ptr + 3 * 16 * (sizeof (int ));
678+ __private char * c4 = c_ptr + 4 * 16 * (sizeof (int ));
679+ __private char * c5 = c_ptr + 5 * 16 * (sizeof (int ));
680+ __private char * c6 = c_ptr + 6 * 16 * (sizeof (int ));
681+ __private char * c7 = c_ptr + 7 * 16 * (sizeof (int ));
682+
683+ __private char * d0 = d_ptr + 0 * 16 * (sizeof (int ));
684+ __private char * d1 = d_ptr + 1 * 16 * (sizeof (int ));
685+ __private char * d2 = d_ptr + 2 * 16 * (sizeof (int ));
686+ __private char * d3 = d_ptr + 3 * 16 * (sizeof (int ));
687+ __private char * d4 = d_ptr + 4 * 16 * (sizeof (int ));
688+ __private char * d5 = d_ptr + 5 * 16 * (sizeof (int ));
689+ __private char * d6 = d_ptr + 6 * 16 * (sizeof (int ));
690+ __private char * d7 = d_ptr + 7 * 16 * (sizeof (int ));
691+
692+ __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32 (a0 , b0 , c0 , d0 );
693+ __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32 (a0 , b1 , c1 , d1 );
694+ __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32 (a0 , b2 , c2 , d2 );
695+ __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32 (a0 , b3 , c3 , d3 );
696+
697+ __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32 (a1 , b0 , c4 , d4 );
698+ __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32 (a1 , b1 , c5 , d5 );
699+ __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32 (a1 , b2 , c6 , d6 );
700+ __builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32 (a1 , b3 , c7 , d7 );
701+ }
702+
703+ INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32 (__private char * dst , char * mem , long stride ) {
704+ IMPLEMENT_BLOCK2D_LOAD_SG16_ROW_MAJOR (int , 32 , int , 32 , 16 , 16 , 16 )
705+ }
706+
707+ INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_16x16_i16_generic_v8i8_pi32_i32 (__private char * dst , char * mem , long stride ) {
708+ IMPLEMENT_BLOCK2D_LOAD_SG16_ROW_MAJOR (short , 16 , short , 16 , 16 , 16 , 16 )
709+ }
710+
711+ INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_32x64_i32_generic_v8i8_pi32_i32 (__private char * dst , char * mem , long stride ) {
712+ __private char * c0 = dst + 0 * 16 * (sizeof (int ));
713+ __private char * c1 = dst + 1 * 16 * (sizeof (int ));
714+ __private char * c2 = dst + 2 * 16 * (sizeof (int ));
715+ __private char * c3 = dst + 3 * 16 * (sizeof (int ));
716+ __private char * c4 = dst + 4 * 16 * (sizeof (int ));
717+ __private char * c5 = dst + 5 * 16 * (sizeof (int ));
718+ __private char * c6 = dst + 6 * 16 * (sizeof (int ));
719+ __private char * c7 = dst + 7 * 16 * (sizeof (int ));
720+
721+ char * mem0 = mem + 0 * 16 * (sizeof (int ));
722+ char * mem1 = mem + 1 * 16 * (sizeof (int ));
723+ char * mem2 = mem + 2 * 16 * (sizeof (int ));
724+ char * mem3 = mem + 3 * 16 * (sizeof (int ));
725+ char * mem4 = mem + 0 * 16 * (sizeof (int )) + 16 * (sizeof (int )) * stride ;
726+ char * mem5 = mem + 1 * 16 * (sizeof (int )) + 16 * (sizeof (int )) * stride ;
727+ char * mem6 = mem + 2 * 16 * (sizeof (int )) + 16 * (sizeof (int )) * stride ;
728+ char * mem7 = mem + 3 * 16 * (sizeof (int )) + 16 * (sizeof (int )) * stride ;
729+
730+ __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32 (c0 , mem0 , stride );
731+ __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32 (c1 , mem1 , stride );
732+ __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32 (c2 , mem2 , stride );
733+ __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32 (c3 , mem3 , stride );
734+ __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32 (c4 , mem4 , stride );
735+ __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32 (c5 , mem5 , stride );
736+ __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32 (c6 , mem6 , stride );
737+ __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_v8i8_pi32_i32 (c7 , mem7 , stride );
738+ }
739+
740+ INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_32x16_i16_generic_v8i8_pi32_i32 (__private char * dst , char * mem , long stride ) {
741+ __private char * dst0 = dst ;
742+ __private char * dst1 = dst + 16 * (sizeof (short ));
743+
744+ char * mem0 = mem ;
745+ char * mem1 = mem + 16 * (sizeof (short )) * stride ;
746+
747+ __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_16x16_i16_generic_v8i8_pi32_i32 (dst0 , mem0 , stride );
748+ __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_16x16_i16_generic_v8i8_pi32_i32 (dst1 , mem1 , stride );
749+ }
750+
751+ INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x64_i16_generic_v8i8_pi32_i32 (__private char * dst , char * mem , long stride ) {
752+ __private char * b0 = dst ;
753+ __private char * b1 = dst + 1 * 16 * (sizeof (short ));
754+ __private char * b2 = dst + 2 * 16 * (sizeof (short ));
755+ __private char * b3 = dst + 3 * 16 * (sizeof (short ));
756+
757+ char * mem0 = mem + 0 * 16 * (sizeof (int ));
758+ char * mem1 = mem + 1 * 16 * (sizeof (int ));
759+ char * mem2 = mem + 2 * 16 * (sizeof (int ));
760+ char * mem3 = mem + 3 * 16 * (sizeof (int ));
761+
762+ __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_generic_v8i8_pi32_i32 (b0 , mem0 , stride );
763+ __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_generic_v8i8_pi32_i32 (b1 , mem1 , stride );
764+ __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_generic_v8i8_pi32_i32 (b2 , mem2 , stride );
765+ __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x16_i16_generic_v8i8_pi32_i32 (b3 , mem3 , stride );
766+ }
767+
768+ INLINE void __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8 (char * mem , __private char * src , long stride ) {
769+ IMPLEMENT_BLOCK2D_STORE_SG16 (int , int , 32 , 16 , 16 , slice )
770+ }
771+
772+ INLINE void __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_32x64_i32_generic_pi64_v8i8 (char * mem , __private char * src , long stride ) {
773+ __private char * c0 = src + 0 * 16 * (sizeof (int ));
774+ __private char * c1 = src + 1 * 16 * (sizeof (int ));
775+ __private char * c2 = src + 2 * 16 * (sizeof (int ));
776+ __private char * c3 = src + 3 * 16 * (sizeof (int ));
777+ __private char * c4 = src + 4 * 16 * (sizeof (int ));
778+ __private char * c5 = src + 5 * 16 * (sizeof (int ));
779+ __private char * c6 = src + 6 * 16 * (sizeof (int ));
780+ __private char * c7 = src + 7 * 16 * (sizeof (int ));
781+
782+ char * mem0 = mem + 0 * 16 * (sizeof (int ));
783+ char * mem1 = mem + 1 * 16 * (sizeof (int ));
784+ char * mem2 = mem + 2 * 16 * (sizeof (int ));
785+ char * mem3 = mem + 3 * 16 * (sizeof (int ));
786+ char * mem4 = mem + 0 * 16 * (sizeof (int )) + 16 * (sizeof (int )) * stride ;
787+ char * mem5 = mem + 1 * 16 * (sizeof (int )) + 16 * (sizeof (int )) * stride ;
788+ char * mem6 = mem + 2 * 16 * (sizeof (int )) + 16 * (sizeof (int )) * stride ;
789+ char * mem7 = mem + 3 * 16 * (sizeof (int )) + 16 * (sizeof (int )) * stride ;
790+
791+ __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8 (mem0 , c0 , stride );
792+ __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8 (mem1 , c1 , stride );
793+ __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8 (mem2 , c2 , stride );
794+ __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8 (mem3 , c3 , stride );
795+ __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8 (mem4 , c4 , stride );
796+ __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8 (mem5 , c5 , stride );
797+ __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8 (mem6 , c6 , stride );
798+ __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_16x16_i32_generic_pi64_v8i8 (mem7 , c7 , stride );
799+ }
800+
0 commit comments