@@ -332,10 +332,99 @@ static u32 dev_iommu_get_max_pasids(struct device *dev)
332332 return min_t (u32 , max_pasids , dev -> iommu -> iommu_dev -> max_pasids );
333333}
334334
335+ /*
336+ * Init the dev->iommu and dev->iommu_group in the struct device and get the
337+ * driver probed
338+ */
339+ static int iommu_init_device (struct device * dev , const struct iommu_ops * ops )
340+ {
341+ struct iommu_device * iommu_dev ;
342+ struct iommu_group * group ;
343+ int ret ;
344+
345+ if (!dev_iommu_get (dev ))
346+ return - ENOMEM ;
347+
348+ if (!try_module_get (ops -> owner )) {
349+ ret = - EINVAL ;
350+ goto err_free ;
351+ }
352+
353+ iommu_dev = ops -> probe_device (dev );
354+ if (IS_ERR (iommu_dev )) {
355+ ret = PTR_ERR (iommu_dev );
356+ goto err_module_put ;
357+ }
358+
359+ group = ops -> device_group (dev );
360+ if (WARN_ON_ONCE (group == NULL ))
361+ group = ERR_PTR (- EINVAL );
362+ if (IS_ERR (group )) {
363+ ret = PTR_ERR (group );
364+ goto err_release ;
365+ }
366+ dev -> iommu_group = group ;
367+
368+ dev -> iommu -> iommu_dev = iommu_dev ;
369+ dev -> iommu -> max_pasids = dev_iommu_get_max_pasids (dev );
370+ if (ops -> is_attach_deferred )
371+ dev -> iommu -> attach_deferred = ops -> is_attach_deferred (dev );
372+ return 0 ;
373+
374+ err_release :
375+ if (ops -> release_device )
376+ ops -> release_device (dev );
377+ err_module_put :
378+ module_put (ops -> owner );
379+ err_free :
380+ dev_iommu_free (dev );
381+ return ret ;
382+ }
383+
384+ static void iommu_deinit_device (struct device * dev )
385+ {
386+ struct iommu_group * group = dev -> iommu_group ;
387+ const struct iommu_ops * ops = dev_iommu_ops (dev );
388+
389+ lockdep_assert_held (& group -> mutex );
390+
391+ /*
392+ * release_device() must stop using any attached domain on the device.
393+ * If there are still other devices in the group they are not effected
394+ * by this callback.
395+ *
396+ * The IOMMU driver must set the device to either an identity or
397+ * blocking translation and stop using any domain pointer, as it is
398+ * going to be freed.
399+ */
400+ if (ops -> release_device )
401+ ops -> release_device (dev );
402+
403+ /*
404+ * If this is the last driver to use the group then we must free the
405+ * domains before we do the module_put().
406+ */
407+ if (list_empty (& group -> devices )) {
408+ if (group -> default_domain ) {
409+ iommu_domain_free (group -> default_domain );
410+ group -> default_domain = NULL ;
411+ }
412+ if (group -> blocking_domain ) {
413+ iommu_domain_free (group -> blocking_domain );
414+ group -> blocking_domain = NULL ;
415+ }
416+ group -> domain = NULL ;
417+ }
418+
419+ /* Caller must put iommu_group */
420+ dev -> iommu_group = NULL ;
421+ module_put (ops -> owner );
422+ dev_iommu_free (dev );
423+ }
424+
335425static int __iommu_probe_device (struct device * dev , struct list_head * group_list )
336426{
337427 const struct iommu_ops * ops = dev -> bus -> iommu_ops ;
338- struct iommu_device * iommu_dev ;
339428 struct iommu_group * group ;
340429 static DEFINE_MUTEX (iommu_probe_device_lock );
341430 int ret ;
@@ -357,62 +446,30 @@ static int __iommu_probe_device(struct device *dev, struct list_head *group_list
357446 goto out_unlock ;
358447 }
359448
360- if (! dev_iommu_get ( dev )) {
361- ret = - ENOMEM ;
449+ ret = iommu_init_device ( dev , ops );
450+ if ( ret )
362451 goto out_unlock ;
363- }
364-
365- if (!try_module_get (ops -> owner )) {
366- ret = - EINVAL ;
367- goto err_free ;
368- }
369-
370- iommu_dev = ops -> probe_device (dev );
371- if (IS_ERR (iommu_dev )) {
372- ret = PTR_ERR (iommu_dev );
373- goto out_module_put ;
374- }
375-
376- dev -> iommu -> iommu_dev = iommu_dev ;
377- dev -> iommu -> max_pasids = dev_iommu_get_max_pasids (dev );
378- if (ops -> is_attach_deferred )
379- dev -> iommu -> attach_deferred = ops -> is_attach_deferred (dev );
380-
381- group = ops -> device_group (dev );
382- if (WARN_ON_ONCE (group == NULL ))
383- group = ERR_PTR (- EINVAL );
384- if (IS_ERR (group )) {
385- ret = PTR_ERR (group );
386- goto out_release ;
387- }
388452
453+ group = dev -> iommu_group ;
389454 ret = iommu_group_add_device (group , dev );
455+ mutex_lock (& group -> mutex );
390456 if (ret )
391457 goto err_put_group ;
392458
393- mutex_lock (& group -> mutex );
394459 if (group_list && !group -> default_domain && list_empty (& group -> entry ))
395460 list_add_tail (& group -> entry , group_list );
396461 mutex_unlock (& group -> mutex );
397462 iommu_group_put (group );
398463
399464 mutex_unlock (& iommu_probe_device_lock );
400- iommu_device_link (iommu_dev , dev );
465+ iommu_device_link (dev -> iommu -> iommu_dev , dev );
401466
402467 return 0 ;
403468
404469err_put_group :
470+ iommu_deinit_device (dev );
471+ mutex_unlock (& group -> mutex );
405472 iommu_group_put (group );
406- out_release :
407- if (ops -> release_device )
408- ops -> release_device (dev );
409-
410- out_module_put :
411- module_put (ops -> owner );
412-
413- err_free :
414- dev_iommu_free (dev );
415-
416473out_unlock :
417474 mutex_unlock (& iommu_probe_device_lock );
418475
@@ -491,63 +548,45 @@ static void __iommu_group_free_device(struct iommu_group *group,
491548
492549 kfree (grp_dev -> name );
493550 kfree (grp_dev );
494- dev -> iommu_group = NULL ;
495551}
496552
497- /*
498- * Remove the iommu_group from the struct device. The attached group must be put
499- * by the caller after releaseing the group->mutex.
500- */
553+ /* Remove the iommu_group from the struct device. */
501554static void __iommu_group_remove_device (struct device * dev )
502555{
503556 struct iommu_group * group = dev -> iommu_group ;
504557 struct group_device * device ;
505558
506- lockdep_assert_held (& group -> mutex );
559+ mutex_lock (& group -> mutex );
507560 for_each_group_device (group , device ) {
508561 if (device -> dev != dev )
509562 continue ;
510563
511564 list_del (& device -> list );
512565 __iommu_group_free_device (group , device );
513- /* Caller must put iommu_group */
514- return ;
566+ if (dev -> iommu && dev -> iommu -> iommu_dev )
567+ iommu_deinit_device (dev );
568+ else
569+ dev -> iommu_group = NULL ;
570+ goto out ;
515571 }
516572 WARN (true, "Corrupted iommu_group device_list" );
573+ out :
574+ mutex_unlock (& group -> mutex );
575+
576+ /* Pairs with the get in iommu_group_add_device() */
577+ iommu_group_put (group );
517578}
518579
519580static void iommu_release_device (struct device * dev )
520581{
521582 struct iommu_group * group = dev -> iommu_group ;
522- const struct iommu_ops * ops ;
523583
524584 if (!dev -> iommu || !group )
525585 return ;
526586
527587 iommu_device_unlink (dev -> iommu -> iommu_dev , dev );
528588
529- mutex_lock (& group -> mutex );
530589 __iommu_group_remove_device (dev );
531-
532- /*
533- * release_device() must stop using any attached domain on the device.
534- * If there are still other devices in the group they are not effected
535- * by this callback.
536- *
537- * The IOMMU driver must set the device to either an identity or
538- * blocking translation and stop using any domain pointer, as it is
539- * going to be freed.
540- */
541- ops = dev_iommu_ops (dev );
542- if (ops -> release_device )
543- ops -> release_device (dev );
544- mutex_unlock (& group -> mutex );
545-
546- /* Pairs with the get in iommu_group_add_device() */
547- iommu_group_put (group );
548-
549- module_put (ops -> owner );
550- dev_iommu_free (dev );
551590}
552591
553592static int __init iommu_set_def_domain_type (char * str )
@@ -808,10 +847,9 @@ static void iommu_group_release(struct kobject *kobj)
808847
809848 ida_free (& iommu_group_ida , group -> id );
810849
811- if (group -> default_domain )
812- iommu_domain_free (group -> default_domain );
813- if (group -> blocking_domain )
814- iommu_domain_free (group -> blocking_domain );
850+ /* Domains are free'd by iommu_deinit_device() */
851+ WARN_ON (group -> default_domain );
852+ WARN_ON (group -> blocking_domain );
815853
816854 kfree (group -> name );
817855 kfree (group );
@@ -1109,12 +1147,7 @@ void iommu_group_remove_device(struct device *dev)
11091147
11101148 dev_info (dev , "Removing from iommu group %d\n" , group -> id );
11111149
1112- mutex_lock (& group -> mutex );
11131150 __iommu_group_remove_device (dev );
1114- mutex_unlock (& group -> mutex );
1115-
1116- /* Pairs with the get in iommu_group_add_device() */
1117- iommu_group_put (group );
11181151}
11191152EXPORT_SYMBOL_GPL (iommu_group_remove_device );
11201153
0 commit comments