@@ -170,6 +170,19 @@ static void __user *apply_user_offset(
170170 return base + offset ;
171171}
172172
173+ struct user_ctxs {
174+ struct fpsimd_context __user * fpsimd ;
175+ u32 fpsimd_size ;
176+ struct sve_context __user * sve ;
177+ u32 sve_size ;
178+ struct tpidr2_context __user * tpidr2 ;
179+ u32 tpidr2_size ;
180+ struct za_context __user * za ;
181+ u32 za_size ;
182+ struct zt_context __user * zt ;
183+ u32 zt_size ;
184+ };
185+
173186static int preserve_fpsimd_context (struct fpsimd_context __user * ctx )
174187{
175188 struct user_fpsimd_state const * fpsimd =
@@ -188,25 +201,20 @@ static int preserve_fpsimd_context(struct fpsimd_context __user *ctx)
188201 return err ? - EFAULT : 0 ;
189202}
190203
191- static int restore_fpsimd_context (struct fpsimd_context __user * ctx )
204+ static int restore_fpsimd_context (struct user_ctxs * user )
192205{
193206 struct user_fpsimd_state fpsimd ;
194- __u32 magic , size ;
195207 int err = 0 ;
196208
197- /* check the magic/size information */
198- __get_user_error (magic , & ctx -> head .magic , err );
199- __get_user_error (size , & ctx -> head .size , err );
200- if (err )
201- return - EFAULT ;
202- if (magic != FPSIMD_MAGIC || size != sizeof (struct fpsimd_context ))
209+ /* check the size information */
210+ if (user -> fpsimd_size != sizeof (struct fpsimd_context ))
203211 return - EINVAL ;
204212
205213 /* copy the FP and status/control registers */
206- err = __copy_from_user (fpsimd .vregs , ctx -> vregs ,
214+ err = __copy_from_user (fpsimd .vregs , & ( user -> fpsimd -> vregs ) ,
207215 sizeof (fpsimd .vregs ));
208- __get_user_error (fpsimd .fpsr , & ctx -> fpsr , err );
209- __get_user_error (fpsimd .fpcr , & ctx -> fpcr , err );
216+ __get_user_error (fpsimd .fpsr , & ( user -> fpsimd -> fpsr ) , err );
217+ __get_user_error (fpsimd .fpcr , & ( user -> fpsimd -> fpcr ) , err );
210218
211219 clear_thread_flag (TIF_SVE );
212220 current -> thread .fp_type = FP_STATE_FPSIMD ;
@@ -219,14 +227,6 @@ static int restore_fpsimd_context(struct fpsimd_context __user *ctx)
219227}
220228
221229
222- struct user_ctxs {
223- struct fpsimd_context __user * fpsimd ;
224- struct sve_context __user * sve ;
225- struct tpidr2_context __user * tpidr2 ;
226- struct za_context __user * za ;
227- struct zt_context __user * zt ;
228- };
229-
230230#ifdef CONFIG_ARM64_SVE
231231
232232static int preserve_sve_context (struct sve_context __user * ctx )
@@ -271,15 +271,20 @@ static int preserve_sve_context(struct sve_context __user *ctx)
271271
272272static int restore_sve_fpsimd_context (struct user_ctxs * user )
273273{
274- int err ;
274+ int err = 0 ;
275275 unsigned int vl , vq ;
276276 struct user_fpsimd_state fpsimd ;
277- struct sve_context sve ;
277+ u16 user_vl , flags ;
278278
279- if (__copy_from_user (& sve , user -> sve , sizeof (sve )))
280- return - EFAULT ;
279+ if (user -> sve_size < sizeof (* user -> sve ))
280+ return - EINVAL ;
281+
282+ __get_user_error (user_vl , & (user -> sve -> vl ), err );
283+ __get_user_error (flags , & (user -> sve -> flags ), err );
284+ if (err )
285+ return err ;
281286
282- if (sve . flags & SVE_SIG_FLAG_SM ) {
287+ if (flags & SVE_SIG_FLAG_SM ) {
283288 if (!system_supports_sme ())
284289 return - EINVAL ;
285290
@@ -291,19 +296,19 @@ static int restore_sve_fpsimd_context(struct user_ctxs *user)
291296 vl = task_get_sve_vl (current );
292297 }
293298
294- if (sve . vl != vl )
299+ if (user_vl != vl )
295300 return - EINVAL ;
296301
297- if (sve . head . size < = sizeof (* user -> sve )) {
302+ if (user -> sve_size = = sizeof (* user -> sve )) {
298303 clear_thread_flag (TIF_SVE );
299304 current -> thread .svcr &= ~SVCR_SM_MASK ;
300305 current -> thread .fp_type = FP_STATE_FPSIMD ;
301306 goto fpsimd_only ;
302307 }
303308
304- vq = sve_vq_from_vl (sve . vl );
309+ vq = sve_vq_from_vl (vl );
305310
306- if (sve . head . size < SVE_SIG_CONTEXT_SIZE (vq ))
311+ if (user -> sve_size < SVE_SIG_CONTEXT_SIZE (vq ))
307312 return - EINVAL ;
308313
309314 /*
@@ -329,7 +334,7 @@ static int restore_sve_fpsimd_context(struct user_ctxs *user)
329334 if (err )
330335 return - EFAULT ;
331336
332- if (sve . flags & SVE_SIG_FLAG_SM )
337+ if (flags & SVE_SIG_FLAG_SM )
333338 current -> thread .svcr |= SVCR_SM_MASK ;
334339 else
335340 set_thread_flag (TIF_SVE );
@@ -383,7 +388,9 @@ static int restore_tpidr2_context(struct user_ctxs *user)
383388 u64 tpidr2_el0 ;
384389 int err = 0 ;
385390
386- /* Magic and size were validated deciding to call this function */
391+ if (user -> tpidr2_size != sizeof (* user -> tpidr2 ))
392+ return - EINVAL ;
393+
387394 __get_user_error (tpidr2_el0 , & user -> tpidr2 -> tpidr2 , err );
388395 if (!err )
389396 current -> thread .tpidr2_el0 = tpidr2_el0 ;
@@ -428,24 +435,28 @@ static int preserve_za_context(struct za_context __user *ctx)
428435
429436static int restore_za_context (struct user_ctxs * user )
430437{
431- int err ;
438+ int err = 0 ;
432439 unsigned int vq ;
433- struct za_context za ;
440+ u16 user_vl ;
434441
435- if (__copy_from_user (& za , user -> za , sizeof (za )))
436- return - EFAULT ;
442+ if (user -> za_size < sizeof (* user -> za ))
443+ return - EINVAL ;
444+
445+ __get_user_error (user_vl , & (user -> za -> vl ), err );
446+ if (err )
447+ return err ;
437448
438- if (za . vl != task_get_sme_vl (current ))
449+ if (user_vl != task_get_sme_vl (current ))
439450 return - EINVAL ;
440451
441- if (za . head . size < = sizeof (* user -> za )) {
452+ if (user -> za_size = = sizeof (* user -> za )) {
442453 current -> thread .svcr &= ~SVCR_ZA_MASK ;
443454 return 0 ;
444455 }
445456
446- vq = sve_vq_from_vl (za . vl );
457+ vq = sve_vq_from_vl (user_vl );
447458
448- if (za . head . size < ZA_SIG_CONTEXT_SIZE (vq ))
459+ if (user -> za_size < ZA_SIG_CONTEXT_SIZE (vq ))
449460 return - EINVAL ;
450461
451462 /*
@@ -510,19 +521,19 @@ static int preserve_zt_context(struct zt_context __user *ctx)
510521static int restore_zt_context (struct user_ctxs * user )
511522{
512523 int err ;
513- struct zt_context zt ;
524+ u16 nregs ;
514525
515526 /* ZA must be restored first for this check to be valid */
516527 if (!thread_za_enabled (& current -> thread ))
517528 return - EINVAL ;
518529
519- if (__copy_from_user (& zt , user -> zt , sizeof (zt )))
520- return - EFAULT ;
521-
522- if (zt .nregs != 1 )
530+ if (user -> zt_size != ZT_SIG_CONTEXT_SIZE (1 ))
523531 return - EINVAL ;
524532
525- if (zt .head .size != ZT_SIG_CONTEXT_SIZE (zt .nregs ))
533+ if (__copy_from_user (& nregs , & (user -> zt -> nregs ), sizeof (nregs )))
534+ return - EFAULT ;
535+
536+ if (nregs != 1 )
526537 return - EINVAL ;
527538
528539 /*
@@ -615,10 +626,8 @@ static int parse_user_sigframe(struct user_ctxs *user,
615626 if (user -> fpsimd )
616627 goto invalid ;
617628
618- if (size < sizeof (* user -> fpsimd ))
619- goto invalid ;
620-
621629 user -> fpsimd = (struct fpsimd_context __user * )head ;
630+ user -> fpsimd_size = size ;
622631 break ;
623632
624633 case ESR_MAGIC :
@@ -632,10 +641,8 @@ static int parse_user_sigframe(struct user_ctxs *user,
632641 if (user -> sve )
633642 goto invalid ;
634643
635- if (size < sizeof (* user -> sve ))
636- goto invalid ;
637-
638644 user -> sve = (struct sve_context __user * )head ;
645+ user -> sve_size = size ;
639646 break ;
640647
641648 case TPIDR2_MAGIC :
@@ -645,10 +652,8 @@ static int parse_user_sigframe(struct user_ctxs *user,
645652 if (user -> tpidr2 )
646653 goto invalid ;
647654
648- if (size != sizeof (* user -> tpidr2 ))
649- goto invalid ;
650-
651655 user -> tpidr2 = (struct tpidr2_context __user * )head ;
656+ user -> tpidr2_size = size ;
652657 break ;
653658
654659 case ZA_MAGIC :
@@ -658,10 +663,8 @@ static int parse_user_sigframe(struct user_ctxs *user,
658663 if (user -> za )
659664 goto invalid ;
660665
661- if (size < sizeof (* user -> za ))
662- goto invalid ;
663-
664666 user -> za = (struct za_context __user * )head ;
667+ user -> za_size = size ;
665668 break ;
666669
667670 case ZT_MAGIC :
@@ -671,10 +674,8 @@ static int parse_user_sigframe(struct user_ctxs *user,
671674 if (user -> zt )
672675 goto invalid ;
673676
674- if (size < sizeof (* user -> zt ))
675- goto invalid ;
676-
677677 user -> zt = (struct zt_context __user * )head ;
678+ user -> zt_size = size ;
678679 break ;
679680
680681 case EXTRA_MAGIC :
@@ -793,7 +794,7 @@ static int restore_sigframe(struct pt_regs *regs,
793794 if (user .sve )
794795 err = restore_sve_fpsimd_context (& user );
795796 else
796- err = restore_fpsimd_context (user . fpsimd );
797+ err = restore_fpsimd_context (& user );
797798 }
798799
799800 if (err == 0 && system_supports_sme () && user .tpidr2 )
0 commit comments