Skip to content

Commit 477dd40

Browse files
committed
Simplier loops.
1 parent 79d9fe3 commit 477dd40

1 file changed

Lines changed: 47 additions & 54 deletions

File tree

kernel/riscv64/sgemm_kernel_16x8_zvl256b.c

Lines changed: 47 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,9 @@ static void FORCEINLINE M_TAIL_ONE(BLASLONG K, const BLASLONG M, const BLASLONG
101101
A3 += (1 * 8);
102102
#endif
103103

104-
for (BLASLONG k = (K / 8); --k; ) {
104+
BLASLONG k = (K / 8);
105+
K &= 7;
106+
while (--k) {
105107
B00 = __riscv_vle32_v_f32m8(B, N * 8);
106108
B0 = __riscv_vget_v_f32m8_f32m1(B00, 0);
107109
B1 = __riscv_vget_v_f32m8_f32m1(B00, 1);
@@ -143,8 +145,6 @@ static void FORCEINLINE M_TAIL_ONE(BLASLONG K, const BLASLONG M, const BLASLONG
143145
resultE = __riscv_vfadd_vv_f32m1(resultE, result1, N);
144146
result3 = __riscv_vfadd_vv_f32m1(result3, result5, N);
145147
resultE = __riscv_vfadd_vv_f32m1(resultE, result3, N);
146-
147-
K &= 7;
148148
} else {
149149
resultE = __riscv_vreinterpret_v_u32m1_f32m1(__riscv_vmv_v_x_u32m1(0, N));
150150
}
@@ -196,7 +196,9 @@ static void FORCEINLINE M_TAIL_ONE(BLASLONG K, const BLASLONG M, const BLASLONG
196196
}
197197
#endif
198198

199-
for (BLASLONG k = (K / 4); --k; ) {
199+
BLASLONG k = (K / 4);
200+
K &= 3;
201+
while (--k) {
200202
B00 = __riscv_vle32_v_f32m4(B, N * 4);
201203
B0 = __riscv_vget_v_f32m4_f32m1(B00, 0);
202204
B1 = __riscv_vget_v_f32m4_f32m1(B00, 1);
@@ -257,8 +259,6 @@ static void FORCEINLINE M_TAIL_ONE(BLASLONG K, const BLASLONG M, const BLASLONG
257259
result6 = __riscv_vfadd_vv_f32m1(result6, resultA, N);
258260
resultE = __riscv_vfadd_vv_f32m1(resultE, result6, N);
259261
}
260-
261-
K &= 3;
262262
} else {
263263
if (M & 2) {
264264
resultC = __riscv_vreinterpret_v_u32m1_f32m1(__riscv_vmv_v_x_u32m1(0, N));
@@ -373,7 +373,9 @@ static void FORCEINLINE M_TAIL_ONE(BLASLONG K, const BLASLONG M, const BLASLONG
373373
#endif
374374
B += (N * 2);
375375

376-
for (BLASLONG k = (K / 2); --k; ) {
376+
BLASLONG k = (K / 2);
377+
K &= 1;
378+
while (--k) {
377379
if (!S2) {
378380
B00 = __riscv_vle32_v_f32m2(B, N * 2);
379381
B0 = __riscv_vget_v_f32m2_f32m1(B00, 0);
@@ -497,8 +499,6 @@ static void FORCEINLINE M_TAIL_ONE(BLASLONG K, const BLASLONG M, const BLASLONG
497499
if (M & 1) {
498500
resultE = __riscv_vfadd_vv_f32m1(resultE, result6, N);
499501
}
500-
501-
K &= 1;
502502
} else {
503503
if (M == 8) {
504504
result0 = __riscv_vreinterpret_v_u32m1_f32m1(__riscv_vmv_v_x_u32m1(0, N));
@@ -1501,21 +1501,16 @@ static void FORCEINLINE N_TAIL_ONE(BLASLONG K, BLASLONG M, const BLASLONG N, FLO
15011501
B04 = B + ((N & 6) * K);
15021502
}
15031503
#endif
1504-
FLOAT K2;
15051504
#ifdef GEMM_BOTTOM_CHUNK
15061505
FLOAT K3;
15071506
if (N == 1) {
15081507
K3 = (K / 8);
1509-
K &= 7;
1508+
K = (K & 7) + 1;
15101509
} else if (N <= 4) {
15111510
K3 = (K / 2);
1512-
K &= 1;
1513-
} else
1514-
#endif
1515-
{
1516-
K--;
1511+
K = (K & 1) + 1;
15171512
}
1518-
K2 = K;
1513+
#endif
15191514
do {
15201515
FLOAT B0, B1, B2, B3, B4, B5, B6;
15211516
#ifdef GEMM_NEW_PACKING
@@ -1960,7 +1955,7 @@ static void FORCEINLINE N_TAIL_ONE(BLASLONG K, BLASLONG M, const BLASLONG N, FLO
19601955
}
19611956
}
19621957

1963-
while (K--) {
1958+
for (BLASLONG k = K; --k; ) {
19641959
if (N & 4) {
19651960
B0 = B00[0];
19661961
B1 = B00[1];
@@ -2023,39 +2018,39 @@ static void FORCEINLINE N_TAIL_ONE(BLASLONG K, BLASLONG M, const BLASLONG N, FLO
20232018
vfloat32m1_t c8, c9, cA, cB, cC, cD;
20242019
vfloat32m2_t c00;
20252020
if (N & 4) {
2026-
c00 = __riscv_vle32_v_f32m2(*C, 16);
2021+
c00 = __riscv_vle32_v_f32m2(C0, 16);
20272022
c0 = __riscv_vget_v_f32m2_f32m1(c00, 0);
20282023
c1 = __riscv_vget_v_f32m2_f32m1(c00, 1);
2029-
*C += ldc;
2030-
c00 = __riscv_vle32_v_f32m2(*C, 16);
2024+
C0 += ldc;
2025+
c00 = __riscv_vle32_v_f32m2(C0, 16);
20312026
c2 = __riscv_vget_v_f32m2_f32m1(c00, 0);
20322027
c3 = __riscv_vget_v_f32m2_f32m1(c00, 1);
2033-
*C += ldc;
2034-
c00 = __riscv_vle32_v_f32m2(*C, 16);
2028+
C0 += ldc;
2029+
c00 = __riscv_vle32_v_f32m2(C0, 16);
20352030
c4 = __riscv_vget_v_f32m2_f32m1(c00, 0);
20362031
c5 = __riscv_vget_v_f32m2_f32m1(c00, 1);
2037-
*C += ldc;
2038-
c00 = __riscv_vle32_v_f32m2(*C, 16);
2032+
C0 += ldc;
2033+
c00 = __riscv_vle32_v_f32m2(C0, 16);
20392034
c6 = __riscv_vget_v_f32m2_f32m1(c00, 0);
20402035
c7 = __riscv_vget_v_f32m2_f32m1(c00, 1);
20412036
if (N & 3) {
2042-
*C += ldc;
2037+
C0 += ldc;
20432038
}
20442039
}
20452040
if (N & 2) {
2046-
c00 = __riscv_vle32_v_f32m2(*C, 16);
2041+
c00 = __riscv_vle32_v_f32m2(C0, 16);
20472042
c8 = __riscv_vget_v_f32m2_f32m1(c00, 0);
20482043
c9 = __riscv_vget_v_f32m2_f32m1(c00, 1);
2049-
*C += ldc;
2050-
c00 = __riscv_vle32_v_f32m2(*C, 16);
2044+
C0 += ldc;
2045+
c00 = __riscv_vle32_v_f32m2(C0, 16);
20512046
cA = __riscv_vget_v_f32m2_f32m1(c00, 0);
20522047
cB = __riscv_vget_v_f32m2_f32m1(c00, 1);
20532048
if (N & 1) {
2054-
*C += ldc;
2049+
C0 += ldc;
20552050
}
20562051
}
20572052
if (N & 1) {
2058-
c00 = __riscv_vle32_v_f32m2(*C, 16);
2053+
c00 = __riscv_vle32_v_f32m2(C0, 16);
20592054
cC = __riscv_vget_v_f32m2_f32m1(c00, 0);
20602055
cD = __riscv_vget_v_f32m2_f32m1(c00, 1);
20612056
}
@@ -2081,40 +2076,38 @@ static void FORCEINLINE N_TAIL_ONE(BLASLONG K, BLASLONG M, const BLASLONG N, FLO
20812076
cD = __riscv_vfmacc_vf_f32m1(cD, alpha, resultD, 8);
20822077
}
20832078

2084-
*C = C0;
2079+
C0 = *C;
2080+
*C += 16;
20852081
if (N & 4) {
20862082
c00 = __riscv_vcreate_v_f32m1_f32m2(c0, c1);
2087-
__riscv_vse32_v_f32m2(*C, c00, 16);
2088-
*C += ldc;
2083+
__riscv_vse32_v_f32m2(C0, c00, 16);
2084+
C0 += ldc;
20892085
c00 = __riscv_vcreate_v_f32m1_f32m2(c2, c3);
2090-
__riscv_vse32_v_f32m2(*C, c00, 16);
2091-
*C += ldc;
2086+
__riscv_vse32_v_f32m2(C0, c00, 16);
2087+
C0 += ldc;
20922088
c00 = __riscv_vcreate_v_f32m1_f32m2(c4, c5);
2093-
__riscv_vse32_v_f32m2(*C, c00, 16);
2094-
*C += ldc;
2089+
__riscv_vse32_v_f32m2(C0, c00, 16);
2090+
C0 += ldc;
20952091
c00 = __riscv_vcreate_v_f32m1_f32m2(c6, c7);
2096-
__riscv_vse32_v_f32m2(*C, c00, 16);
2092+
__riscv_vse32_v_f32m2(C0, c00, 16);
20972093
if (N & 3) {
2098-
*C += ldc;
2094+
C0 += ldc;
20992095
}
21002096
}
21012097
if (N & 2) {
21022098
c00 = __riscv_vcreate_v_f32m1_f32m2(c8, c9);
2103-
__riscv_vse32_v_f32m2(*C, c00, 16);
2104-
*C += ldc;
2099+
__riscv_vse32_v_f32m2(C0, c00, 16);
2100+
C0 += ldc;
21052101
c00 = __riscv_vcreate_v_f32m1_f32m2(cA, cB);
2106-
__riscv_vse32_v_f32m2(*C, c00, 16);
2102+
__riscv_vse32_v_f32m2(C0, c00, 16);
21072103
if (N & 1) {
2108-
*C += ldc;
2104+
C0 += ldc;
21092105
}
21102106
}
21112107
if (N & 1) {
21122108
c00 = __riscv_vcreate_v_f32m1_f32m2(cC, cD);
2113-
__riscv_vse32_v_f32m2(*C, c00, 16);
2109+
__riscv_vse32_v_f32m2(C0, c00, 16);
21142110
}
2115-
2116-
*C = C0 + 16;
2117-
K = K2;
21182111
} while (--M);
21192112
}
21202113

@@ -2176,8 +2169,8 @@ static void FORCEINLINE N_TAIL(BLASLONG K, const BLASLONG M, const BLASLONG N, F
21762169

21772170
static void NM_TAIL(BLASLONG K, BLASLONG M, const BLASLONG m_edge, const BLASLONG N, const BLASLONG S, FLOAT alpha, FLOAT* A, FLOAT* B, FLOAT* C, BLASLONG ldc)
21782171
{
2179-
if (M / 16) {
2180-
N_TAIL(K, M / 16, N, alpha, &A, B, &C, ldc);
2172+
if (M) {
2173+
N_TAIL(K, M, N, alpha, &A, B, &C, ldc);
21812174
}
21822175
if (m_edge) {
21832176
if (N & 4) {
@@ -2216,10 +2209,10 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT* A, FLOAT* B, F
22162209

22172210
FLOAT *C01 = C;
22182211
FLOAT *A00 = A;
2219-
for (BLASLONG j=0; j<N/8; j+=1) {
2212+
for (BLASLONG j = (N / 8); j--; ) {
22202213

22212214
FLOAT *B00 = B;
2222-
for (BLASLONG i=0; i<M/16; i+=1) {
2215+
for (BLASLONG i = (M / 16); i--; ) {
22232216
B = B00;
22242217
FLOAT B0 = B[0];
22252218
FLOAT B1 = B[1];
@@ -2252,7 +2245,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT* A, FLOAT* B, F
22522245
vfloat32m1_t result14 = __riscv_vfmul_vf_f32m1( A0, B7, 8 );
22532246
vfloat32m1_t result15 = __riscv_vfmul_vf_f32m1( A1, B7, 8 );
22542247

2255-
for(BLASLONG k=1; k<K; k++) {
2248+
for (BLASLONG k = K; --k; ) {
22562249
B0 = B[0];
22572250
B1 = B[1];
22582251
B2 = B[2];
@@ -2364,7 +2357,7 @@ int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT alpha, FLOAT* A, FLOAT* B, F
23642357
// -- tails for N<=7
23652358

23662359
if (N & 7) {
2367-
NM_TAIL(K, M, m_edge, N, S, alpha, A, B, C, ldc);
2360+
NM_TAIL(K, M / 16, m_edge, N, S, alpha, A, B, C, ldc);
23682361
}
23692362

23702363
return 0;

0 commit comments

Comments
 (0)