Commit bc6d6a4
Author: Eric Biggers
lib/crypto: x86/sha256: Add support for 2-way interleaved hashing
Add an implementation of sha256_finup_2x_arch() for x86_64. It interleaves the computation of two SHA-256 hashes using the x86 SHA-NI instructions. dm-verity and fs-verity will take advantage of this for greatly improved performance on capable CPUs.

This increases the throughput of SHA-256 hashing 4096-byte messages by the following amounts on the following CPUs:

    Intel Ice Lake (server):         4%
    Intel Sapphire Rapids:          38%
    Intel Emerald Rapids:           38%
    AMD Zen 1 (Threadripper 1950X): 84%
    AMD Zen 4 (EPYC 9B14):          98%
    AMD Zen 5 (Ryzen 9 9950X):      64%

For now, this seems to benefit AMD more than Intel. This seems to be because current AMD CPUs support concurrent execution of the SHA-NI instructions, but unfortunately current Intel CPUs don't, except for the sha256msg2 instruction. Hopefully future Intel CPUs will support SHA-NI on more execution ports.

Zen 1 supports 2 concurrent sha256rnds2, and Zen 4 supports 4 concurrent sha256rnds2, which suggests that even better performance may be achievable on Zen 4 by interleaving more than two hashes. However, doing so poses a number of trade-offs, and furthermore Zen 5 goes back to supporting "only" 2 concurrent sha256rnds2.

Reviewed-by: Ard Biesheuvel <ardb@kernel.org>
Link: https://lore.kernel.org/r/20250915160819.140019-4-ebiggers@kernel.org
Signed-off-by: Eric Biggers <ebiggers@kernel.org>
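To make the interface concrete, here is a short C sketch of how a caller might use the new routine. The prototype of sha256_ni_finup2x() is taken verbatim from the comment in the assembly below; the wrapper function, the salted-context parameter, and the 4096-byte length are illustrative assumptions (in-kernel users such as dm-verity reach this code through sha256_finup_2x_arch() rather than calling the assembly directly).

    /*
     * Illustrative sketch only -- not part of the commit.  The prototype of
     * sha256_ni_finup2x() matches the comment in sha256-ni-asm.S; everything
     * else here is a hypothetical caller.
     */
    #include <crypto/sha2.h>        /* SHA256_DIGEST_SIZE, SHA256_BLOCK_SIZE */

    void sha256_ni_finup2x(const struct __sha256_ctx *ctx,
                           const u8 *data1, const u8 *data2, int len,
                           u8 out1[SHA256_DIGEST_SIZE],
                           u8 out2[SHA256_DIGEST_SIZE]);

    /* Hash two equal-length data blocks, starting from a shared context. */
    static void hash_two_blocks(const struct __sha256_ctx *salted_ctx,
                                const u8 *blk1, const u8 *blk2,
                                u8 digest1[SHA256_DIGEST_SIZE],
                                u8 digest2[SHA256_DIGEST_SIZE])
    {
            /* Contract: both messages are 4096 bytes (>= SHA256_BLOCK_SIZE). */
            sha256_ni_finup2x(salted_ctx, blk1, blk2, 4096, digest1, digest2);
    }

Producing both digests in one interleaved pass is what yields the throughput gains quoted above on CPUs that can execute SHA-NI instructions concurrently.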
1 parent 34c3f1e commit bc6d6a4

2 files changed: 407 additions & 0 deletions

File: lib/crypto/x86/sha256-ni-asm.S (368 additions & 0 deletions)
@@ -165,6 +165,374 @@ SYM_FUNC_START(sha256_ni_transform)
	RET
SYM_FUNC_END(sha256_ni_transform)

#undef DIGEST_PTR
#undef DATA_PTR
#undef NUM_BLKS
#undef SHA256CONSTANTS
#undef MSG
#undef STATE0
#undef STATE1
#undef MSG0
#undef MSG1
#undef MSG2
#undef MSG3
#undef TMP
#undef SHUF_MASK
#undef ABEF_SAVE
#undef CDGH_SAVE

// parameters for sha256_ni_finup2x()
#define CTX		%rdi
#define DATA1		%rsi
#define DATA2		%rdx
#define LEN		%ecx
#define LEN8		%cl
#define LEN64		%rcx
#define OUT1		%r8
#define OUT2		%r9

// other scalar variables
#define SHA256CONSTANTS	%rax
#define COUNT		%r10
#define COUNT32		%r10d
#define FINAL_STEP	%r11d

// rbx is used as a temporary.

#define MSG		%xmm0	// sha256rnds2 implicit operand
#define STATE0_A	%xmm1
#define STATE1_A	%xmm2
#define STATE0_B	%xmm3
#define STATE1_B	%xmm4
#define TMP_A		%xmm5
#define TMP_B		%xmm6
#define MSG0_A		%xmm7
#define MSG1_A		%xmm8
#define MSG2_A		%xmm9
#define MSG3_A		%xmm10
#define MSG0_B		%xmm11
#define MSG1_B		%xmm12
#define MSG2_B		%xmm13
#define MSG3_B		%xmm14
#define SHUF_MASK	%xmm15

#define OFFSETOF_STATE		0	// offsetof(struct __sha256_ctx, state)
#define OFFSETOF_BYTECOUNT	32	// offsetof(struct __sha256_ctx, bytecount)
#define OFFSETOF_BUF		40	// offsetof(struct __sha256_ctx, buf)

// Do 4 rounds of SHA-256 for each of two messages (interleaved).  m0_a and m0_b
// contain the current 4 message schedule words for the first and second message
// respectively.
//
// If not all the message schedule words have been computed yet, then this also
// computes 4 more message schedule words for each message.  m1_a-m3_a contain
// the next 3 groups of 4 message schedule words for the first message, and
// likewise m1_b-m3_b for the second.  After consuming the current value of
// m0_a, this macro computes the group after m3_a and writes it to m0_a, and
// likewise for *_b.  This means that the next (m0_a, m1_a, m2_a, m3_a) is the
// current (m1_a, m2_a, m3_a, m0_a), and likewise for *_b, so the caller must
// cycle through the registers accordingly.
.macro	do_4rounds_2x	i, m0_a, m1_a, m2_a, m3_a, m0_b, m1_b, m2_b, m3_b
	movdqa		(\i-32)*4(SHA256CONSTANTS), TMP_A
	movdqa		TMP_A, TMP_B
	paddd		\m0_a, TMP_A
	paddd		\m0_b, TMP_B
.if \i < 48
	sha256msg1	\m1_a, \m0_a
	sha256msg1	\m1_b, \m0_b
.endif
	movdqa		TMP_A, MSG
	sha256rnds2	STATE0_A, STATE1_A
	movdqa		TMP_B, MSG
	sha256rnds2	STATE0_B, STATE1_B
	pshufd		$0x0E, TMP_A, MSG
	sha256rnds2	STATE1_A, STATE0_A
	pshufd		$0x0E, TMP_B, MSG
	sha256rnds2	STATE1_B, STATE0_B
.if \i < 48
	movdqa		\m3_a, TMP_A
	movdqa		\m3_b, TMP_B
	palignr		$4, \m2_a, TMP_A
	palignr		$4, \m2_b, TMP_B
	paddd		TMP_A, \m0_a
	paddd		TMP_B, \m0_b
	sha256msg2	\m3_a, \m0_a
	sha256msg2	\m3_b, \m0_b
.endif
.endm
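// For example, the first .irp iteration in sha256_ni_finup2x() below invokes
// this macro with (MSG0_A, MSG1_A, MSG2_A, MSG3_A), then with
// (MSG1_A, MSG2_A, MSG3_A, MSG0_A), and so on, rotating the register
// assignments by one position per invocation.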

//
// void sha256_ni_finup2x(const struct __sha256_ctx *ctx,
//			  const u8 *data1, const u8 *data2, int len,
//			  u8 out1[SHA256_DIGEST_SIZE],
//			  u8 out2[SHA256_DIGEST_SIZE]);
//
// This function computes the SHA-256 digests of two messages |data1| and
// |data2| that are both |len| bytes long, starting from the initial context
// |ctx|.  |len| must be at least SHA256_BLOCK_SIZE.
//
// The instructions for the two SHA-256 operations are interleaved.  On many
// CPUs, this is almost twice as fast as hashing each message individually due
// to taking better advantage of the CPU's SHA-256 and SIMD throughput.
//
SYM_FUNC_START(sha256_ni_finup2x)
	// Allocate 128 bytes of stack space, 16-byte aligned.
	push		%rbx
	push		%rbp
	mov		%rsp, %rbp
	sub		$128, %rsp
	and		$~15, %rsp

	// Load the shuffle mask for swapping the endianness of 32-bit words.
	movdqa		PSHUFFLE_BYTE_FLIP_MASK(%rip), SHUF_MASK

	// Set up pointer to the round constants.
	lea		K256+32*4(%rip), SHA256CONSTANTS

	// Initially we're not processing the final blocks.
	xor		FINAL_STEP, FINAL_STEP

	// Load the initial state from ctx->state.
	movdqu		OFFSETOF_STATE+0*16(CTX), STATE0_A	// DCBA
	movdqu		OFFSETOF_STATE+1*16(CTX), STATE1_A	// HGFE
	movdqa		STATE0_A, TMP_A
	punpcklqdq	STATE1_A, STATE0_A		// FEBA
	punpckhqdq	TMP_A, STATE1_A			// DCHG
	pshufd		$0x1B, STATE0_A, STATE0_A	// ABEF
	pshufd		$0xB1, STATE1_A, STATE1_A	// CDGH

	// Load ctx->bytecount.  Take the mod 64 of it to get the number of
	// bytes that are buffered in ctx->buf.  Also save it in a register with
	// LEN added to it.
	mov		LEN, LEN	// 32-bit mov zero-extends LEN into LEN64
	mov		OFFSETOF_BYTECOUNT(CTX), %rbx
	lea		(%rbx, LEN64, 1), COUNT
	and		$63, %ebx
	jz		.Lfinup2x_enter_loop	// No bytes buffered?

	// %ebx bytes (1 to 63) are currently buffered in ctx->buf.  Load them
	// followed by the first 64 - %ebx bytes of data.  Since LEN >= 64, we
	// just load 64 bytes from each of ctx->buf, DATA1, and DATA2
	// unconditionally and rearrange the data as needed.
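	// (For example, if %ebx = 40: the stores below put the 40 buffered
	// bytes at sp[0..39] and the first 64 data bytes at sp[40..103], so
	// reloading sp[0..63] yields the buffered bytes followed by the first
	// 24 data bytes.  The stack buffer is 128 bytes, so the writes past
	// sp[63] are harmless.)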

	movdqu		OFFSETOF_BUF+0*16(CTX), MSG0_A
	movdqu		OFFSETOF_BUF+1*16(CTX), MSG1_A
	movdqu		OFFSETOF_BUF+2*16(CTX), MSG2_A
	movdqu		OFFSETOF_BUF+3*16(CTX), MSG3_A
	movdqa		MSG0_A, 0*16(%rsp)
	movdqa		MSG1_A, 1*16(%rsp)
	movdqa		MSG2_A, 2*16(%rsp)
	movdqa		MSG3_A, 3*16(%rsp)

	movdqu		0*16(DATA1), MSG0_A
	movdqu		1*16(DATA1), MSG1_A
	movdqu		2*16(DATA1), MSG2_A
	movdqu		3*16(DATA1), MSG3_A
	movdqu		MSG0_A, 0*16(%rsp,%rbx)
	movdqu		MSG1_A, 1*16(%rsp,%rbx)
	movdqu		MSG2_A, 2*16(%rsp,%rbx)
	movdqu		MSG3_A, 3*16(%rsp,%rbx)
	movdqa		0*16(%rsp), MSG0_A
	movdqa		1*16(%rsp), MSG1_A
	movdqa		2*16(%rsp), MSG2_A
	movdqa		3*16(%rsp), MSG3_A

	movdqu		0*16(DATA2), MSG0_B
	movdqu		1*16(DATA2), MSG1_B
	movdqu		2*16(DATA2), MSG2_B
	movdqu		3*16(DATA2), MSG3_B
	movdqu		MSG0_B, 0*16(%rsp,%rbx)
	movdqu		MSG1_B, 1*16(%rsp,%rbx)
	movdqu		MSG2_B, 2*16(%rsp,%rbx)
	movdqu		MSG3_B, 3*16(%rsp,%rbx)
	movdqa		0*16(%rsp), MSG0_B
	movdqa		1*16(%rsp), MSG1_B
	movdqa		2*16(%rsp), MSG2_B
	movdqa		3*16(%rsp), MSG3_B

	sub		$64, %rbx	// rbx = buffered - 64
	sub		%rbx, DATA1	// DATA1 += 64 - buffered
	sub		%rbx, DATA2	// DATA2 += 64 - buffered
	add		%ebx, LEN	// LEN += buffered - 64
	movdqa		STATE0_A, STATE0_B
	movdqa		STATE1_A, STATE1_B
	jmp		.Lfinup2x_loop_have_data

.Lfinup2x_enter_loop:
	sub		$64, LEN
	movdqa		STATE0_A, STATE0_B
	movdqa		STATE1_A, STATE1_B
.Lfinup2x_loop:
	// Load the next two data blocks.
	movdqu		0*16(DATA1), MSG0_A
	movdqu		0*16(DATA2), MSG0_B
	movdqu		1*16(DATA1), MSG1_A
	movdqu		1*16(DATA2), MSG1_B
	movdqu		2*16(DATA1), MSG2_A
	movdqu		2*16(DATA2), MSG2_B
	movdqu		3*16(DATA1), MSG3_A
	movdqu		3*16(DATA2), MSG3_B
	add		$64, DATA1
	add		$64, DATA2
.Lfinup2x_loop_have_data:
	// Convert the words of the data blocks from big endian.
	pshufb		SHUF_MASK, MSG0_A
	pshufb		SHUF_MASK, MSG0_B
	pshufb		SHUF_MASK, MSG1_A
	pshufb		SHUF_MASK, MSG1_B
	pshufb		SHUF_MASK, MSG2_A
	pshufb		SHUF_MASK, MSG2_B
	pshufb		SHUF_MASK, MSG3_A
	pshufb		SHUF_MASK, MSG3_B
.Lfinup2x_loop_have_bswapped_data:

	// Save the original state for each block.
	movdqa		STATE0_A, 0*16(%rsp)
	movdqa		STATE0_B, 1*16(%rsp)
	movdqa		STATE1_A, 2*16(%rsp)
	movdqa		STATE1_B, 3*16(%rsp)

	// Do the SHA-256 rounds on each block.
.irp i, 0, 16, 32, 48
	do_4rounds_2x	(\i + 0),  MSG0_A, MSG1_A, MSG2_A, MSG3_A, \
				   MSG0_B, MSG1_B, MSG2_B, MSG3_B
	do_4rounds_2x	(\i + 4),  MSG1_A, MSG2_A, MSG3_A, MSG0_A, \
				   MSG1_B, MSG2_B, MSG3_B, MSG0_B
	do_4rounds_2x	(\i + 8),  MSG2_A, MSG3_A, MSG0_A, MSG1_A, \
				   MSG2_B, MSG3_B, MSG0_B, MSG1_B
	do_4rounds_2x	(\i + 12), MSG3_A, MSG0_A, MSG1_A, MSG2_A, \
				   MSG3_B, MSG0_B, MSG1_B, MSG2_B
.endr

	// Add the original state for each block.
	paddd		0*16(%rsp), STATE0_A
	paddd		1*16(%rsp), STATE0_B
	paddd		2*16(%rsp), STATE1_A
	paddd		3*16(%rsp), STATE1_B

	// Update LEN and loop back if more blocks remain.
	sub		$64, LEN
	jge		.Lfinup2x_loop

	// Check if any final blocks need to be handled.
	// FINAL_STEP = 2: all done
	// FINAL_STEP = 1: need to do count-only padding block
	// FINAL_STEP = 0: need to do the block with 0x80 padding byte
	cmp		$1, FINAL_STEP
	jg		.Lfinup2x_done
	je		.Lfinup2x_finalize_countonly
	add		$64, LEN
	jz		.Lfinup2x_finalize_blockaligned

	// Not block-aligned; 1 <= LEN <= 63 data bytes remain.  Pad the block.
	// To do this, write the padding starting with the 0x80 byte to
	// &sp[64].  Then for each message, copy the last 64 data bytes to sp
	// and load from &sp[64 - LEN] to get the needed padding block.  This
	// code relies on the data buffers being >= 64 bytes in length.
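	// (For example, if LEN = 5, then %rbx = 59 below: sp[0..63] receives
	// the last 64 data bytes, so the reload from &sp[59] yields the final
	// 5 data bytes, then the 0x80 byte, then zeroes, with the message bit
	// count (as __be64) at block offset 56, since LEN < 56 means the
	// count fits in this block.)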
	mov		$64, %ebx
	sub		LEN, %ebx	// ebx = 64 - LEN
	sub		%rbx, DATA1	// DATA1 -= 64 - LEN
	sub		%rbx, DATA2	// DATA2 -= 64 - LEN
	mov		$0x80, FINAL_STEP	// using FINAL_STEP as a temporary
	movd		FINAL_STEP, MSG0_A
	pxor		MSG1_A, MSG1_A
	movdqa		MSG0_A, 4*16(%rsp)
	movdqa		MSG1_A, 5*16(%rsp)
	movdqa		MSG1_A, 6*16(%rsp)
	movdqa		MSG1_A, 7*16(%rsp)
	cmp		$56, LEN
	jge		1f	// will COUNT spill into its own block?
	shl		$3, COUNT
	bswap		COUNT
	mov		COUNT, 56(%rsp,%rbx)
	mov		$2, FINAL_STEP	// won't need count-only block
	jmp		2f
1:
	mov		$1, FINAL_STEP	// will need count-only block
2:
	movdqu		0*16(DATA1), MSG0_A
	movdqu		1*16(DATA1), MSG1_A
	movdqu		2*16(DATA1), MSG2_A
	movdqu		3*16(DATA1), MSG3_A
	movdqa		MSG0_A, 0*16(%rsp)
	movdqa		MSG1_A, 1*16(%rsp)
	movdqa		MSG2_A, 2*16(%rsp)
	movdqa		MSG3_A, 3*16(%rsp)
	movdqu		0*16(%rsp,%rbx), MSG0_A
	movdqu		1*16(%rsp,%rbx), MSG1_A
	movdqu		2*16(%rsp,%rbx), MSG2_A
	movdqu		3*16(%rsp,%rbx), MSG3_A

	movdqu		0*16(DATA2), MSG0_B
	movdqu		1*16(DATA2), MSG1_B
	movdqu		2*16(DATA2), MSG2_B
	movdqu		3*16(DATA2), MSG3_B
	movdqa		MSG0_B, 0*16(%rsp)
	movdqa		MSG1_B, 1*16(%rsp)
	movdqa		MSG2_B, 2*16(%rsp)
	movdqa		MSG3_B, 3*16(%rsp)
	movdqu		0*16(%rsp,%rbx), MSG0_B
	movdqu		1*16(%rsp,%rbx), MSG1_B
	movdqu		2*16(%rsp,%rbx), MSG2_B
	movdqu		3*16(%rsp,%rbx), MSG3_B
	jmp		.Lfinup2x_loop_have_data

	// Prepare a padding block, either:
	//
	//	{0x80, 0, 0, 0, ..., count (as __be64)}
	//	This is for a block aligned message.
	//
	//	{   0, 0, 0, 0, ..., count (as __be64)}
	//	This is for a message whose length mod 64 is >= 56.
	//
	// Pre-swap the endianness of the words.
.Lfinup2x_finalize_countonly:
	pxor		MSG0_A, MSG0_A
	jmp		1f

.Lfinup2x_finalize_blockaligned:
	mov		$0x80000000, %ebx
	movd		%ebx, MSG0_A
1:
	pxor		MSG1_A, MSG1_A
	pxor		MSG2_A, MSG2_A
	ror		$29, COUNT
	movq		COUNT, MSG3_A
	pslldq		$8, MSG3_A
	movdqa		MSG0_A, MSG0_B
	pxor		MSG1_B, MSG1_B
	pxor		MSG2_B, MSG2_B
	movdqa		MSG3_A, MSG3_B
	mov		$2, FINAL_STEP
	jmp		.Lfinup2x_loop_have_bswapped_data

.Lfinup2x_done:
	// Write the two digests with all bytes in the correct order.
	movdqa		STATE0_A, TMP_A
	movdqa		STATE0_B, TMP_B
	punpcklqdq	STATE1_A, STATE0_A		// GHEF
	punpcklqdq	STATE1_B, STATE0_B
	punpckhqdq	TMP_A, STATE1_A			// ABCD
	punpckhqdq	TMP_B, STATE1_B
	pshufd		$0xB1, STATE0_A, STATE0_A	// HGFE
	pshufd		$0xB1, STATE0_B, STATE0_B
	pshufd		$0x1B, STATE1_A, STATE1_A	// DCBA
	pshufd		$0x1B, STATE1_B, STATE1_B
	pshufb		SHUF_MASK, STATE0_A
	pshufb		SHUF_MASK, STATE0_B
	pshufb		SHUF_MASK, STATE1_A
	pshufb		SHUF_MASK, STATE1_B
	movdqu		STATE0_A, 1*16(OUT1)
	movdqu		STATE0_B, 1*16(OUT2)
	movdqu		STATE1_A, 0*16(OUT1)
	movdqu		STATE1_B, 0*16(OUT2)

	mov		%rbp, %rsp
	pop		%rbp
	pop		%rbx
	RET
SYM_FUNC_END(sha256_ni_finup2x)

.section .rodata.cst256.K256, "aM", @progbits, 256
.align 64
K256:
