@@ -382,7 +382,7 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
382382{
383383 BLASLONG i ,j ,k ;
384384 FLOAT * C0 ,* C1 ,* C2 ,* C3 ;
385- FLOAT * ptrba ,* ptrbb ;
385+ FLOAT * ptrba ,* ptrbb , * tmpc ;
386386
387387 FLOAT loadb0 ,loadb1 ,loadb2 ,loadb3 ;
388388 FLOAT load0 ,load1 ,load2 ,load3 ,load4 ,load5 ,load6 ,load7 ;
@@ -392,6 +392,7 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
392392 FLOAT res8 ,res9 ,res10 ,res11 ;
393393 FLOAT res12 ,res13 ,res14 ,res15 ;
394394
395+
395396 for (j = 0 ; j < bn /4 ; j += 1 ){
396397 C0 = C ;
397398 C1 = C0 + ldc ;
@@ -942,53 +943,109 @@ int CNAME(BLASLONG bm,BLASLONG bn,BLASLONG bk,FLOAT alpha,FLOAT* ba,FLOAT* bb,FL
942943 }
943944 if (bm & 1 ){
944945 ptrbb = bb ;
945-
946- res0 = 0 ;
947-
948- res4 = 0 ;
949-
950- res8 = 0 ;
951-
952- res12 = 0 ;
953-
954- for (k = 0 ; k < bk ; k += 1 ){
955- loadb0 = ptrbb [0 ];
956- loadb1 = ptrbb [1 ];
957-
958- load0 = ptrba [0 ];
959-
960- res0 = res0 + load0 * loadb0 ;
961-
962- res4 = res4 + load0 * loadb1 ;
963-
964- loadb2 = ptrbb [2 ];
965- loadb3 = ptrbb [3 ];
966-
967- res8 = res8 + load0 * loadb2 ;
968-
969- res12 = res12 + load0 * loadb3 ;
970-
971- ptrba += 1 ;
972- ptrbb += 4 ;
973- }
974-
975- res0 = res0 * alpha ;
976-
977- res4 = res4 * alpha ;
946+ //t0 for k
947+ //ft0-ft3,ft4-ft7,v8-v15 for B, t1-t3 for PB1-3
948+ //v0-v3,v4-v7 for A, t4-t6 for PA1-3
949+ //v16-v31 for temp C
978950
979- res8 = res8 * alpha ;
951+ FLOAT tmp [4 ];
952+ tmpc = tmp ;
953+ //t1-t3 for PB
954+ //v0-v4 for A, v8-v11 for B
955+ //v16-v19 for C
956+ asm volatile (
957+ "vsetvli zero, zero, e32,m1 \n\t"
958+ "fmv.w.x ft11, zero \n\t"
959+
960+ "vfmv.v.f v16, ft11 \n\t"
961+ "vfmv.v.f v17, ft11 \n\t"
962+ "vfmv.v.f v18, ft11 \n\t"
963+ "vfmv.v.f v19, ft11 \n\t"
964+ //unroll by 4
980965
981- res12 = res12 * alpha ;
966+ "srli t0, %[BK], 2 \n\t"
967+ "blez t0, M1x4_TAIL \n\t"
982968
983- C0 [0 ] += res0 ;
984- C1 [0 ] += res4 ;
985- C2 [0 ] += res8 ;
986- C3 [0 ] += res12 ;
969+ "addi t1, %[PB], 4*4 \n\t"
970+ "addi t2, %[PB], 8*4 \n\t"
971+ "addi t3, %[PB], 12*4 \n\t"
972+
973+ ".align 4 \n\t"
974+ "M1x4_MAINLOOP: \n\t"
987975
976+ "vle.v v4, (%[PA]) \n\t"
977+ "addi %[PA], %[PA], 4*4 \n\t"
978+ "vrgather.vi v0, v4, 0 \n\t"
979+
980+ "vle.v v8, (%[PB]) \n\t"
981+ "addi %[PB], %[PB], 16*4 \n\t"
982+ "vrgather.vi v1, v4, 1 \n\t"
983+
984+ "vle.v v9, (t1) \n\t"
985+ "addi t1, t1, 16*4 \n\t"
986+ "vrgather.vi v2, v4, 2 \n\t"
987+
988+ "vle.v v10, (t2) \n\t"
989+ "addi t2, t2, 16*4 \n\t"
990+ "vrgather.vi v3, v4, 3 \n\t"
991+
992+ "vle.v v11, (t3) \n\t"
993+ "addi t3, t3, 16*4 \n\t"
994+
995+ "vfmacc.vv v16, v8, v0 \n\t"
996+ "vfmacc.vv v17, v9, v1 \n\t"
997+ "vfmacc.vv v18, v10, v2 \n\t"
998+ "vfmacc.vv v19, v11, v3 \n\t"
999+
1000+ "addi t0, t0, -1 \n\t"
1001+ "bgtz t0, M1x4_MAINLOOP \n\t"
1002+
1003+ "M1x4_TAIL: \n\t"
1004+ "andi t0, %[BK], 3 \n\t"
1005+ "blez t0, M1x4_SAVERESULT \n\t"
1006+
1007+ "M1x4_TAILLOOP: \n\t"
1008+ "flw ft0, (%[PA]) \n\t"
1009+ "addi %[PA], %[PA], 1*4 \n\t"
1010+ "vle.v v8, (%[PB]) \n\t"
1011+ "addi %[PB], %[PB], 4*4 \n\t"
1012+ "vfmv.v.f v0, ft0 \n\t"
1013+ "vfmacc.vv v16, v8, v0 \n\t"
1014+
1015+ "addi t0, t0, -1 \n\t"
1016+ "bgtz t0, M1x4_TAILLOOP \n\t"
1017+
1018+ "M1x4_SAVERESULT: \n\t"
1019+ //merge v16-v19
1020+ "vfadd.vv v16, v16, v17 \n\t"
1021+ "vfadd.vv v18, v18, v19 \n\t"
1022+ "vfadd.vv v16, v16, v18 \n\t"
1023+
1024+ "vfmv.v.f v8, %[ALPHA] \n\t"
1025+ "vfmul.vv v16, v8, v16 \n\t"
1026+ "vse.v v16, (%[TMP_C]) \n\t"
1027+ "M1x4_END: \n\t"
1028+ :[TMP_C ]"+r" (tmpc ),
1029+ [PA ]"+r" (ptrba ), [PB ]"+r" (ptrbb )
1030+ :[ALPHA ]"f" (alpha ), [BK ]"r" (bk )
1031+ :"cc" , "t0" , "t3" ,"t1" ,"t2" ,
1032+ "ft0" , "ft11" ,
1033+ "v0" , "v1" , "v2" , "v3" ,"v4" ,
1034+ "v8" , "v9" , "v10" , "v11" ,
1035+ "v16" , "v17" ,"v18" , "v19"
1036+ );
1037+
1038+ C0 [0 ] += tmp [0 ];
1039+ C1 [0 ] += tmp [1 ];
1040+ C2 [0 ] += tmp [2 ];
1041+ C3 [0 ] += tmp [3 ];
1042+
1043+ /* no need to advance the C pointers
9881044 C0 += 1;
9891045 C1 += 1;
9901046 C2 += 1;
9911047 C3 += 1;
1048+ */
9921049 }
9931050
9941051 k = bk <<2 ;
0 commit comments