@@ -631,6 +631,138 @@ static int genl_validate_ops(const struct genl_family *family)
631631 return 0 ;
632632}
633633
634+ static void * genl_sk_priv_alloc (struct genl_family * family )
635+ {
636+ void * priv ;
637+
638+ priv = kzalloc (family -> sock_priv_size , GFP_KERNEL );
639+ if (!priv )
640+ return ERR_PTR (- ENOMEM );
641+
642+ if (family -> sock_priv_init )
643+ family -> sock_priv_init (priv );
644+
645+ return priv ;
646+ }
647+
648+ static void genl_sk_priv_free (const struct genl_family * family , void * priv )
649+ {
650+ if (family -> sock_priv_destroy )
651+ family -> sock_priv_destroy (priv );
652+ kfree (priv );
653+ }
654+
655+ static int genl_sk_privs_alloc (struct genl_family * family )
656+ {
657+ if (!family -> sock_priv_size )
658+ return 0 ;
659+
660+ family -> sock_privs = kzalloc (sizeof (* family -> sock_privs ), GFP_KERNEL );
661+ if (!family -> sock_privs )
662+ return - ENOMEM ;
663+ xa_init (family -> sock_privs );
664+ return 0 ;
665+ }
666+
667+ static void genl_sk_privs_free (const struct genl_family * family )
668+ {
669+ unsigned long id ;
670+ void * priv ;
671+
672+ if (!family -> sock_priv_size )
673+ return ;
674+
675+ xa_for_each (family -> sock_privs , id , priv )
676+ genl_sk_priv_free (family , priv );
677+
678+ xa_destroy (family -> sock_privs );
679+ kfree (family -> sock_privs );
680+ }
681+
682+ static void genl_sk_priv_free_by_sock (struct genl_family * family ,
683+ struct sock * sk )
684+ {
685+ void * priv ;
686+
687+ if (!family -> sock_priv_size )
688+ return ;
689+ priv = xa_erase (family -> sock_privs , (unsigned long ) sk );
690+ if (!priv )
691+ return ;
692+ genl_sk_priv_free (family , priv );
693+ }
694+
695+ static void genl_release (struct sock * sk , unsigned long * groups )
696+ {
697+ struct genl_family * family ;
698+ unsigned int id ;
699+
700+ down_read (& cb_lock );
701+
702+ idr_for_each_entry (& genl_fam_idr , family , id )
703+ genl_sk_priv_free_by_sock (family , sk );
704+
705+ up_read (& cb_lock );
706+ }
707+
708+ /**
709+ * __genl_sk_priv_get - Get family private pointer for socket, if exists
710+ *
711+ * @family: family
712+ * @sk: socket
713+ *
714+ * Lookup a private memory for a Generic netlink family and specified socket.
715+ *
716+ * Caller should make sure this is called in RCU read locked section.
717+ *
718+ * Return: valid pointer on success, otherwise negative error value
719+ * encoded by ERR_PTR(), NULL in case priv does not exist.
720+ */
721+ void * __genl_sk_priv_get (struct genl_family * family , struct sock * sk )
722+ {
723+ if (WARN_ON_ONCE (!family -> sock_privs ))
724+ return ERR_PTR (- EINVAL );
725+ return xa_load (family -> sock_privs , (unsigned long ) sk );
726+ }
727+
728+ /**
729+ * genl_sk_priv_get - Get family private pointer for socket
730+ *
731+ * @family: family
732+ * @sk: socket
733+ *
734+ * Lookup a private memory for a Generic netlink family and specified socket.
735+ * Allocate the private memory in case it was not already done.
736+ *
737+ * Return: valid pointer on success, otherwise negative error value
738+ * encoded by ERR_PTR().
739+ */
740+ void * genl_sk_priv_get (struct genl_family * family , struct sock * sk )
741+ {
742+ void * priv , * old_priv ;
743+
744+ priv = __genl_sk_priv_get (family , sk );
745+ if (priv )
746+ return priv ;
747+
748+ /* priv for the family does not exist so far, create it. */
749+
750+ priv = genl_sk_priv_alloc (family );
751+ if (IS_ERR (priv ))
752+ return ERR_CAST (priv );
753+
754+ old_priv = xa_cmpxchg (family -> sock_privs , (unsigned long ) sk , NULL ,
755+ priv , GFP_KERNEL );
756+ if (old_priv ) {
757+ genl_sk_priv_free (family , priv );
758+ if (xa_is_err (old_priv ))
759+ return ERR_PTR (xa_err (old_priv ));
760+ /* Race happened, priv for the socket was already inserted. */
761+ return old_priv ;
762+ }
763+ return priv ;
764+ }
765+
634766/**
635767 * genl_register_family - register a generic netlink family
636768 * @family: generic netlink family
@@ -659,6 +791,10 @@ int genl_register_family(struct genl_family *family)
659791 goto errout_locked ;
660792 }
661793
794+ err = genl_sk_privs_alloc (family );
795+ if (err )
796+ goto errout_locked ;
797+
662798 /*
663799 * Sadly, a few cases need to be special-cased
664800 * due to them having previously abused the API
@@ -679,7 +815,7 @@ int genl_register_family(struct genl_family *family)
679815 start , end + 1 , GFP_KERNEL );
680816 if (family -> id < 0 ) {
681817 err = family -> id ;
682- goto errout_locked ;
818+ goto errout_sk_privs_free ;
683819 }
684820
685821 err = genl_validate_assign_mc_groups (family );
@@ -698,6 +834,8 @@ int genl_register_family(struct genl_family *family)
698834
699835errout_remove :
700836 idr_remove (& genl_fam_idr , family -> id );
837+ errout_sk_privs_free :
838+ genl_sk_privs_free (family );
701839errout_locked :
702840 genl_unlock_all ();
703841 return err ;
@@ -728,6 +866,9 @@ int genl_unregister_family(const struct genl_family *family)
728866 up_write (& cb_lock );
729867 wait_event (genl_sk_destructing_waitq ,
730868 atomic_read (& genl_sk_destructing_cnt ) == 0 );
869+
870+ genl_sk_privs_free (family );
871+
731872 genl_unlock ();
732873
733874 genl_ctrl_event (CTRL_CMD_DELFAMILY , family , NULL , 0 );
@@ -1708,6 +1849,7 @@ static int __net_init genl_pernet_init(struct net *net)
17081849 .input = genl_rcv ,
17091850 .flags = NL_CFG_F_NONROOT_RECV ,
17101851 .bind = genl_bind ,
1852+ .release = genl_release ,
17111853 };
17121854
17131855 /* we'll bump the group number right afterwards */
0 commit comments