Skip to content

Commit 22b7950

Browse files
committed
Use LMUL=2 for calculations in the main block — just break them apart before the last stage.
1 parent 3b1aef1 commit 22b7950

2 files changed

Lines changed: 68 additions & 74 deletions

File tree

kernel/riscv64/dgemm_kernel_8x8_zvl256b.c

Lines changed: 34 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -1654,27 +1654,16 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT* A, FLOAT* B, F
16541654
B += 8;
16551655

16561656
vfloat64m2_t A00 = __riscv_vle64_v_f64m2( A, 8 );
1657-
vfloat64m1_t A0 = __riscv_vget_v_f64m2_f64m1(A00, 0);
1658-
vfloat64m1_t A1 = __riscv_vget_v_f64m2_f64m1(A00, 1);
16591657
A += 8;
16601658

1661-
// LMUL = 2 does worst here
1662-
vfloat64m1_t result0 = __riscv_vfmul_vf_f64m1( A0, B0, 4 );
1663-
vfloat64m1_t result1 = __riscv_vfmul_vf_f64m1( A1, B0, 4 );
1664-
vfloat64m1_t result2 = __riscv_vfmul_vf_f64m1( A0, B1, 4 );
1665-
vfloat64m1_t result3 = __riscv_vfmul_vf_f64m1( A1, B1, 4 );
1666-
vfloat64m1_t result4 = __riscv_vfmul_vf_f64m1( A0, B2, 4 );
1667-
vfloat64m1_t result5 = __riscv_vfmul_vf_f64m1( A1, B2, 4 );
1668-
vfloat64m1_t result6 = __riscv_vfmul_vf_f64m1( A0, B3, 4 );
1669-
vfloat64m1_t result7 = __riscv_vfmul_vf_f64m1( A1, B3, 4 );
1670-
vfloat64m1_t result8 = __riscv_vfmul_vf_f64m1( A0, B4, 4 );
1671-
vfloat64m1_t result9 = __riscv_vfmul_vf_f64m1( A1, B4, 4 );
1672-
vfloat64m1_t result10 = __riscv_vfmul_vf_f64m1( A0, B5, 4 );
1673-
vfloat64m1_t result11 = __riscv_vfmul_vf_f64m1( A1, B5, 4 );
1674-
vfloat64m1_t result12 = __riscv_vfmul_vf_f64m1( A0, B6, 4 );
1675-
vfloat64m1_t result13 = __riscv_vfmul_vf_f64m1( A1, B6, 4 );
1676-
vfloat64m1_t result14 = __riscv_vfmul_vf_f64m1( A0, B7, 4 );
1677-
vfloat64m1_t result15 = __riscv_vfmul_vf_f64m1( A1, B7, 4 );
1659+
vfloat64m2_t result01 = __riscv_vfmul_vf_f64m2( A00, B0, 8 );
1660+
vfloat64m2_t result23 = __riscv_vfmul_vf_f64m2( A00, B1, 8 );
1661+
vfloat64m2_t result45 = __riscv_vfmul_vf_f64m2( A00, B2, 8 );
1662+
vfloat64m2_t result67 = __riscv_vfmul_vf_f64m2( A00, B3, 8 );
1663+
vfloat64m2_t result89 = __riscv_vfmul_vf_f64m2( A00, B4, 8 );
1664+
vfloat64m2_t resultAB = __riscv_vfmul_vf_f64m2( A00, B5, 8 );
1665+
vfloat64m2_t resultCD = __riscv_vfmul_vf_f64m2( A00, B6, 8 );
1666+
vfloat64m2_t resultEF = __riscv_vfmul_vf_f64m2( A00, B7, 8 );
16781667

16791668
for (BLASLONG k = K; --k; ) {
16801669
B0 = B[0];
@@ -1688,28 +1677,36 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT* A, FLOAT* B, F
16881677
B += 8;
16891678

16901679
A00 = __riscv_vle64_v_f64m2( A, 8 );
1691-
A0 = __riscv_vget_v_f64m2_f64m1(A00, 0);
1692-
A1 = __riscv_vget_v_f64m2_f64m1(A00, 1);
16931680
A += 8;
16941681

1695-
result0 = __riscv_vfmacc_vf_f64m1( result0, B0, A0, 4 );
1696-
result1 = __riscv_vfmacc_vf_f64m1( result1, B0, A1, 4 );
1697-
result2 = __riscv_vfmacc_vf_f64m1( result2, B1, A0, 4 );
1698-
result3 = __riscv_vfmacc_vf_f64m1( result3, B1, A1, 4 );
1699-
result4 = __riscv_vfmacc_vf_f64m1( result4, B2, A0, 4 );
1700-
result5 = __riscv_vfmacc_vf_f64m1( result5, B2, A1, 4 );
1701-
result6 = __riscv_vfmacc_vf_f64m1( result6, B3, A0, 4 );
1702-
result7 = __riscv_vfmacc_vf_f64m1( result7, B3, A1, 4 );
1703-
result8 = __riscv_vfmacc_vf_f64m1( result8, B4, A0, 4 );
1704-
result9 = __riscv_vfmacc_vf_f64m1( result9, B4, A1, 4 );
1705-
result10 = __riscv_vfmacc_vf_f64m1( result10, B5, A0, 4 );
1706-
result11 = __riscv_vfmacc_vf_f64m1( result11, B5, A1, 4 );
1707-
result12 = __riscv_vfmacc_vf_f64m1( result12, B6, A0, 4 );
1708-
result13 = __riscv_vfmacc_vf_f64m1( result13, B6, A1, 4 );
1709-
result14 = __riscv_vfmacc_vf_f64m1( result14, B7, A0, 4 );
1710-
result15 = __riscv_vfmacc_vf_f64m1( result15, B7, A1, 4 );
1682+
result01 = __riscv_vfmacc_vf_f64m2( result01, B0, A00, 8 );
1683+
result23 = __riscv_vfmacc_vf_f64m2( result23, B1, A00, 8 );
1684+
result45 = __riscv_vfmacc_vf_f64m2( result45, B2, A00, 8 );
1685+
result67 = __riscv_vfmacc_vf_f64m2( result67, B3, A00, 8 );
1686+
result89 = __riscv_vfmacc_vf_f64m2( result89, B4, A00, 8 );
1687+
resultAB = __riscv_vfmacc_vf_f64m2( resultAB, B5, A00, 8 );
1688+
resultCD = __riscv_vfmacc_vf_f64m2( resultCD, B6, A00, 8 );
1689+
resultEF = __riscv_vfmacc_vf_f64m2( resultEF, B7, A00, 8 );
17111690
}
17121691

1692+
// LMUL = 2 does worst here
1693+
vfloat64m1_t result0 = __riscv_vget_v_f64m2_f64m1(result01, 0);
1694+
vfloat64m1_t result1 = __riscv_vget_v_f64m2_f64m1(result01, 1);
1695+
vfloat64m1_t result2 = __riscv_vget_v_f64m2_f64m1(result23, 0);
1696+
vfloat64m1_t result3 = __riscv_vget_v_f64m2_f64m1(result23, 1);
1697+
vfloat64m1_t result4 = __riscv_vget_v_f64m2_f64m1(result45, 0);
1698+
vfloat64m1_t result5 = __riscv_vget_v_f64m2_f64m1(result45, 1);
1699+
vfloat64m1_t result6 = __riscv_vget_v_f64m2_f64m1(result67, 0);
1700+
vfloat64m1_t result7 = __riscv_vget_v_f64m2_f64m1(result67, 1);
1701+
vfloat64m1_t result8 = __riscv_vget_v_f64m2_f64m1(result89, 0);
1702+
vfloat64m1_t result9 = __riscv_vget_v_f64m2_f64m1(result89, 1);
1703+
vfloat64m1_t result10 = __riscv_vget_v_f64m2_f64m1(resultAB, 0);
1704+
vfloat64m1_t result11 = __riscv_vget_v_f64m2_f64m1(resultAB, 1);
1705+
vfloat64m1_t result12 = __riscv_vget_v_f64m2_f64m1(resultCD, 0);
1706+
vfloat64m1_t result13 = __riscv_vget_v_f64m2_f64m1(resultCD, 1);
1707+
vfloat64m1_t result14 = __riscv_vget_v_f64m2_f64m1(resultEF, 0);
1708+
vfloat64m1_t result15 = __riscv_vget_v_f64m2_f64m1(resultEF, 1);
1709+
17131710
FLOAT *C2 = C;
17141711

17151712
vfloat64m2_t c00;

kernel/riscv64/sgemm_kernel_16x8_zvl256b.c

Lines changed: 34 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -2159,27 +2159,16 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT* A, FLOAT* B, F
21592159
B += 8;
21602160

21612161
vfloat32m2_t A00 = __riscv_vle32_v_f32m2( A, 16 );
2162-
vfloat32m1_t A0 = __riscv_vget_v_f32m2_f32m1(A00, 0);
2163-
vfloat32m1_t A1 = __riscv_vget_v_f32m2_f32m1(A00, 1);
21642162
A += 16;
21652163

2166-
// LMUL = 2 does worst here
2167-
vfloat32m1_t result0 = __riscv_vfmul_vf_f32m1( A0, B0, 8 );
2168-
vfloat32m1_t result1 = __riscv_vfmul_vf_f32m1( A1, B0, 8 );
2169-
vfloat32m1_t result2 = __riscv_vfmul_vf_f32m1( A0, B1, 8 );
2170-
vfloat32m1_t result3 = __riscv_vfmul_vf_f32m1( A1, B1, 8 );
2171-
vfloat32m1_t result4 = __riscv_vfmul_vf_f32m1( A0, B2, 8 );
2172-
vfloat32m1_t result5 = __riscv_vfmul_vf_f32m1( A1, B2, 8 );
2173-
vfloat32m1_t result6 = __riscv_vfmul_vf_f32m1( A0, B3, 8 );
2174-
vfloat32m1_t result7 = __riscv_vfmul_vf_f32m1( A1, B3, 8 );
2175-
vfloat32m1_t result8 = __riscv_vfmul_vf_f32m1( A0, B4, 8 );
2176-
vfloat32m1_t result9 = __riscv_vfmul_vf_f32m1( A1, B4, 8 );
2177-
vfloat32m1_t result10 = __riscv_vfmul_vf_f32m1( A0, B5, 8 );
2178-
vfloat32m1_t result11 = __riscv_vfmul_vf_f32m1( A1, B5, 8 );
2179-
vfloat32m1_t result12 = __riscv_vfmul_vf_f32m1( A0, B6, 8 );
2180-
vfloat32m1_t result13 = __riscv_vfmul_vf_f32m1( A1, B6, 8 );
2181-
vfloat32m1_t result14 = __riscv_vfmul_vf_f32m1( A0, B7, 8 );
2182-
vfloat32m1_t result15 = __riscv_vfmul_vf_f32m1( A1, B7, 8 );
2164+
vfloat32m2_t result01 = __riscv_vfmul_vf_f32m2( A00, B0, 16 );
2165+
vfloat32m2_t result23 = __riscv_vfmul_vf_f32m2( A00, B1, 16 );
2166+
vfloat32m2_t result45 = __riscv_vfmul_vf_f32m2( A00, B2, 16 );
2167+
vfloat32m2_t result67 = __riscv_vfmul_vf_f32m2( A00, B3, 16 );
2168+
vfloat32m2_t result89 = __riscv_vfmul_vf_f32m2( A00, B4, 16 );
2169+
vfloat32m2_t resultAB = __riscv_vfmul_vf_f32m2( A00, B5, 16 );
2170+
vfloat32m2_t resultCD = __riscv_vfmul_vf_f32m2( A00, B6, 16 );
2171+
vfloat32m2_t resultEF = __riscv_vfmul_vf_f32m2( A00, B7, 16 );
21832172

21842173
for (BLASLONG k = K; --k; ) {
21852174
B0 = B[0];
@@ -2193,28 +2182,36 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT* A, FLOAT* B, F
21932182
B += 8;
21942183

21952184
A00 = __riscv_vle32_v_f32m2( A, 16 );
2196-
A0 = __riscv_vget_v_f32m2_f32m1(A00, 0);
2197-
A1 = __riscv_vget_v_f32m2_f32m1(A00, 1);
21982185
A += 16;
21992186

2200-
result0 = __riscv_vfmacc_vf_f32m1( result0, B0, A0, 8 );
2201-
result1 = __riscv_vfmacc_vf_f32m1( result1, B0, A1, 8 );
2202-
result2 = __riscv_vfmacc_vf_f32m1( result2, B1, A0, 8 );
2203-
result3 = __riscv_vfmacc_vf_f32m1( result3, B1, A1, 8 );
2204-
result4 = __riscv_vfmacc_vf_f32m1( result4, B2, A0, 8 );
2205-
result5 = __riscv_vfmacc_vf_f32m1( result5, B2, A1, 8 );
2206-
result6 = __riscv_vfmacc_vf_f32m1( result6, B3, A0, 8 );
2207-
result7 = __riscv_vfmacc_vf_f32m1( result7, B3, A1, 8 );
2208-
result8 = __riscv_vfmacc_vf_f32m1( result8, B4, A0, 8 );
2209-
result9 = __riscv_vfmacc_vf_f32m1( result9, B4, A1, 8 );
2210-
result10 = __riscv_vfmacc_vf_f32m1( result10, B5, A0, 8 );
2211-
result11 = __riscv_vfmacc_vf_f32m1( result11, B5, A1, 8 );
2212-
result12 = __riscv_vfmacc_vf_f32m1( result12, B6, A0, 8 );
2213-
result13 = __riscv_vfmacc_vf_f32m1( result13, B6, A1, 8 );
2214-
result14 = __riscv_vfmacc_vf_f32m1( result14, B7, A0, 8 );
2215-
result15 = __riscv_vfmacc_vf_f32m1( result15, B7, A1, 8 );
2187+
result01 = __riscv_vfmacc_vf_f32m2( result01, B0, A00, 16 );
2188+
result23 = __riscv_vfmacc_vf_f32m2( result23, B1, A00, 16 );
2189+
result45 = __riscv_vfmacc_vf_f32m2( result45, B2, A00, 16 );
2190+
result67 = __riscv_vfmacc_vf_f32m2( result67, B3, A00, 16 );
2191+
result89 = __riscv_vfmacc_vf_f32m2( result89, B4, A00, 16 );
2192+
resultAB = __riscv_vfmacc_vf_f32m2( resultAB, B5, A00, 16 );
2193+
resultCD = __riscv_vfmacc_vf_f32m2( resultCD, B6, A00, 16 );
2194+
resultEF = __riscv_vfmacc_vf_f32m2( resultEF, B7, A00, 16 );
22162195
}
22172196

2197+
// LMUL = 2 does worst here
2198+
vfloat32m1_t result0 = __riscv_vget_v_f32m2_f32m1(result01, 0);
2199+
vfloat32m1_t result1 = __riscv_vget_v_f32m2_f32m1(result01, 1);
2200+
vfloat32m1_t result2 = __riscv_vget_v_f32m2_f32m1(result23, 0);
2201+
vfloat32m1_t result3 = __riscv_vget_v_f32m2_f32m1(result23, 1);
2202+
vfloat32m1_t result4 = __riscv_vget_v_f32m2_f32m1(result45, 0);
2203+
vfloat32m1_t result5 = __riscv_vget_v_f32m2_f32m1(result45, 1);
2204+
vfloat32m1_t result6 = __riscv_vget_v_f32m2_f32m1(result67, 0);
2205+
vfloat32m1_t result7 = __riscv_vget_v_f32m2_f32m1(result67, 1);
2206+
vfloat32m1_t result8 = __riscv_vget_v_f32m2_f32m1(result89, 0);
2207+
vfloat32m1_t result9 = __riscv_vget_v_f32m2_f32m1(result89, 1);
2208+
vfloat32m1_t result10 = __riscv_vget_v_f32m2_f32m1(resultAB, 0);
2209+
vfloat32m1_t result11 = __riscv_vget_v_f32m2_f32m1(resultAB, 1);
2210+
vfloat32m1_t result12 = __riscv_vget_v_f32m2_f32m1(resultCD, 0);
2211+
vfloat32m1_t result13 = __riscv_vget_v_f32m2_f32m1(resultCD, 1);
2212+
vfloat32m1_t result14 = __riscv_vget_v_f32m2_f32m1(resultEF, 0);
2213+
vfloat32m1_t result15 = __riscv_vget_v_f32m2_f32m1(resultEF, 1);
2214+
22182215
FLOAT *C2 = C;
22192216

22202217
vfloat32m2_t c00;

0 commit comments

Comments
 (0)