66#include "allowedips.h"
77#include "peer.h"
88
9+ static struct kmem_cache * node_cache ;
10+
911static void swap_endian (u8 * dst , const u8 * src , u8 bits )
1012{
1113 if (bits == 32 ) {
@@ -28,8 +30,11 @@ static void copy_and_assign_cidr(struct allowedips_node *node, const u8 *src,
2830 node -> bitlen = bits ;
2931 memcpy (node -> bits , src , bits / 8U );
3032}
31- #define CHOOSE_NODE (parent , key ) \
32- parent->bit[(key[parent->bit_at_a] >> parent->bit_at_b) & 1]
33+
34+ static inline u8 choose (struct allowedips_node * node , const u8 * key )
35+ {
36+ return (key [node -> bit_at_a ] >> node -> bit_at_b ) & 1 ;
37+ }
3338
3439static void push_rcu (struct allowedips_node * * stack ,
3540 struct allowedips_node __rcu * p , unsigned int * len )
@@ -40,6 +45,11 @@ static void push_rcu(struct allowedips_node **stack,
4045 }
4146}
4247
48+ static void node_free_rcu (struct rcu_head * rcu )
49+ {
50+ kmem_cache_free (node_cache , container_of (rcu , struct allowedips_node , rcu ));
51+ }
52+
4353static void root_free_rcu (struct rcu_head * rcu )
4454{
4555 struct allowedips_node * node , * stack [128 ] = {
@@ -49,7 +59,7 @@ static void root_free_rcu(struct rcu_head *rcu)
4959 while (len > 0 && (node = stack [-- len ])) {
5060 push_rcu (stack , node -> bit [0 ], & len );
5161 push_rcu (stack , node -> bit [1 ], & len );
52- kfree ( node );
62+ kmem_cache_free ( node_cache , node );
5363 }
5464}
5565
@@ -66,60 +76,6 @@ static void root_remove_peer_lists(struct allowedips_node *root)
6676 }
6777}
6878
69- static void walk_remove_by_peer (struct allowedips_node __rcu * * top ,
70- struct wg_peer * peer , struct mutex * lock )
71- {
72- #define REF (p ) rcu_access_pointer(p)
73- #define DEREF (p ) rcu_dereference_protected(*(p), lockdep_is_held(lock))
74- #define PUSH (p ) ({ \
75- WARN_ON(IS_ENABLED(DEBUG) && len >= 128); \
76- stack[len++] = p; \
77- })
78-
79- struct allowedips_node __rcu * * stack [128 ], * * nptr ;
80- struct allowedips_node * node , * prev ;
81- unsigned int len ;
82-
83- if (unlikely (!peer || !REF (* top )))
84- return ;
85-
86- for (prev = NULL , len = 0 , PUSH (top ); len > 0 ; prev = node ) {
87- nptr = stack [len - 1 ];
88- node = DEREF (nptr );
89- if (!node ) {
90- -- len ;
91- continue ;
92- }
93- if (!prev || REF (prev -> bit [0 ]) == node ||
94- REF (prev -> bit [1 ]) == node ) {
95- if (REF (node -> bit [0 ]))
96- PUSH (& node -> bit [0 ]);
97- else if (REF (node -> bit [1 ]))
98- PUSH (& node -> bit [1 ]);
99- } else if (REF (node -> bit [0 ]) == prev ) {
100- if (REF (node -> bit [1 ]))
101- PUSH (& node -> bit [1 ]);
102- } else {
103- if (rcu_dereference_protected (node -> peer ,
104- lockdep_is_held (lock )) == peer ) {
105- RCU_INIT_POINTER (node -> peer , NULL );
106- list_del_init (& node -> peer_list );
107- if (!node -> bit [0 ] || !node -> bit [1 ]) {
108- rcu_assign_pointer (* nptr , DEREF (
109- & node -> bit [!REF (node -> bit [0 ])]));
110- kfree_rcu (node , rcu );
111- node = DEREF (nptr );
112- }
113- }
114- -- len ;
115- }
116- }
117-
118- #undef REF
119- #undef DEREF
120- #undef PUSH
121- }
122-
12379static unsigned int fls128 (u64 a , u64 b )
12480{
12581 return a ? fls64 (a ) + 64U : fls64 (b );
@@ -159,7 +115,7 @@ static struct allowedips_node *find_node(struct allowedips_node *trie, u8 bits,
159115 found = node ;
160116 if (node -> cidr == bits )
161117 break ;
162- node = rcu_dereference_bh (CHOOSE_NODE (node , key ));
118+ node = rcu_dereference_bh (node -> bit [ choose (node , key )] );
163119 }
164120 return found ;
165121}
@@ -191,8 +147,7 @@ static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key,
191147 u8 cidr , u8 bits , struct allowedips_node * * rnode ,
192148 struct mutex * lock )
193149{
194- struct allowedips_node * node = rcu_dereference_protected (trie ,
195- lockdep_is_held (lock ));
150+ struct allowedips_node * node = rcu_dereference_protected (trie , lockdep_is_held (lock ));
196151 struct allowedips_node * parent = NULL ;
197152 bool exact = false;
198153
@@ -202,13 +157,24 @@ static bool node_placement(struct allowedips_node __rcu *trie, const u8 *key,
202157 exact = true;
203158 break ;
204159 }
205- node = rcu_dereference_protected (CHOOSE_NODE (parent , key ),
206- lockdep_is_held (lock ));
160+ node = rcu_dereference_protected (parent -> bit [choose (parent , key )], lockdep_is_held (lock ));
207161 }
208162 * rnode = parent ;
209163 return exact ;
210164}
211165
166+ static inline void connect_node (struct allowedips_node * * parent , u8 bit , struct allowedips_node * node )
167+ {
168+ node -> parent_bit_packed = (unsigned long )parent | bit ;
169+ rcu_assign_pointer (* parent , node );
170+ }
171+
172+ static inline void choose_and_connect_node (struct allowedips_node * parent , struct allowedips_node * node )
173+ {
174+ u8 bit = choose (parent , node -> bits );
175+ connect_node (& parent -> bit [bit ], bit , node );
176+ }
177+
212178static int add (struct allowedips_node __rcu * * trie , u8 bits , const u8 * key ,
213179 u8 cidr , struct wg_peer * peer , struct mutex * lock )
214180{
@@ -218,13 +184,13 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
218184 return - EINVAL ;
219185
220186 if (!rcu_access_pointer (* trie )) {
221- node = kzalloc ( sizeof ( * node ) , GFP_KERNEL );
187+ node = kmem_cache_zalloc ( node_cache , GFP_KERNEL );
222188 if (unlikely (!node ))
223189 return - ENOMEM ;
224190 RCU_INIT_POINTER (node -> peer , peer );
225191 list_add_tail (& node -> peer_list , & peer -> allowedips_list );
226192 copy_and_assign_cidr (node , key , cidr , bits );
227- rcu_assign_pointer ( * trie , node );
193+ connect_node ( trie , 2 , node );
228194 return 0 ;
229195 }
230196 if (node_placement (* trie , key , cidr , bits , & node , lock )) {
@@ -233,7 +199,7 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
233199 return 0 ;
234200 }
235201
236- newnode = kzalloc ( sizeof ( * newnode ) , GFP_KERNEL );
202+ newnode = kmem_cache_zalloc ( node_cache , GFP_KERNEL );
237203 if (unlikely (!newnode ))
238204 return - ENOMEM ;
239205 RCU_INIT_POINTER (newnode -> peer , peer );
@@ -243,41 +209,40 @@ static int add(struct allowedips_node __rcu **trie, u8 bits, const u8 *key,
243209 if (!node ) {
244210 down = rcu_dereference_protected (* trie , lockdep_is_held (lock ));
245211 } else {
246- down = rcu_dereference_protected ( CHOOSE_NODE ( node , key ),
247- lockdep_is_held (lock ));
212+ const u8 bit = choose ( node , key );
213+ down = rcu_dereference_protected ( node -> bit [ bit ], lockdep_is_held (lock ));
248214 if (!down ) {
249- rcu_assign_pointer ( CHOOSE_NODE ( node , key ) , newnode );
215+ connect_node ( & node -> bit [ bit ], bit , newnode );
250216 return 0 ;
251217 }
252218 }
253219 cidr = min (cidr , common_bits (down , key , bits ));
254220 parent = node ;
255221
256222 if (newnode -> cidr == cidr ) {
257- rcu_assign_pointer ( CHOOSE_NODE ( newnode , down -> bits ) , down );
223+ choose_and_connect_node ( newnode , down );
258224 if (!parent )
259- rcu_assign_pointer ( * trie , newnode );
225+ connect_node ( trie , 2 , newnode );
260226 else
261- rcu_assign_pointer (CHOOSE_NODE (parent , newnode -> bits ),
262- newnode );
263- } else {
264- node = kzalloc (sizeof (* node ), GFP_KERNEL );
265- if (unlikely (!node )) {
266- list_del (& newnode -> peer_list );
267- kfree (newnode );
268- return - ENOMEM ;
269- }
270- INIT_LIST_HEAD (& node -> peer_list );
271- copy_and_assign_cidr (node , newnode -> bits , cidr , bits );
227+ choose_and_connect_node (parent , newnode );
228+ return 0 ;
229+ }
272230
273- rcu_assign_pointer (CHOOSE_NODE (node , down -> bits ), down );
274- rcu_assign_pointer (CHOOSE_NODE (node , newnode -> bits ), newnode );
275- if (!parent )
276- rcu_assign_pointer (* trie , node );
277- else
278- rcu_assign_pointer (CHOOSE_NODE (parent , node -> bits ),
279- node );
231+ node = kmem_cache_zalloc (node_cache , GFP_KERNEL );
232+ if (unlikely (!node )) {
233+ list_del (& newnode -> peer_list );
234+ kmem_cache_free (node_cache , newnode );
235+ return - ENOMEM ;
280236 }
237+ INIT_LIST_HEAD (& node -> peer_list );
238+ copy_and_assign_cidr (node , newnode -> bits , cidr , bits );
239+
240+ choose_and_connect_node (node , down );
241+ choose_and_connect_node (node , newnode );
242+ if (!parent )
243+ connect_node (trie , 2 , node );
244+ else
245+ choose_and_connect_node (parent , node );
281246 return 0 ;
282247}
283248
@@ -335,9 +300,41 @@ int wg_allowedips_insert_v6(struct allowedips *table, const struct in6_addr *ip,
335300void wg_allowedips_remove_by_peer (struct allowedips * table ,
336301 struct wg_peer * peer , struct mutex * lock )
337302{
303+ struct allowedips_node * node , * child , * * parent_bit , * parent , * tmp ;
304+ bool free_parent ;
305+
306+ if (list_empty (& peer -> allowedips_list ))
307+ return ;
338308 ++ table -> seq ;
339- walk_remove_by_peer (& table -> root4 , peer , lock );
340- walk_remove_by_peer (& table -> root6 , peer , lock );
309+ list_for_each_entry_safe (node , tmp , & peer -> allowedips_list , peer_list ) {
310+ list_del_init (& node -> peer_list );
311+ RCU_INIT_POINTER (node -> peer , NULL );
312+ if (node -> bit [0 ] && node -> bit [1 ])
313+ continue ;
314+ child = rcu_dereference_protected (node -> bit [!rcu_access_pointer (node -> bit [0 ])],
315+ lockdep_is_held (lock ));
316+ if (child )
317+ child -> parent_bit_packed = node -> parent_bit_packed ;
318+ parent_bit = (struct allowedips_node * * )(node -> parent_bit_packed & ~3UL );
319+ * parent_bit = child ;
320+ parent = (void * )parent_bit -
321+ offsetof(struct allowedips_node , bit [node -> parent_bit_packed & 1 ]);
322+ free_parent = !rcu_access_pointer (node -> bit [0 ]) &&
323+ !rcu_access_pointer (node -> bit [1 ]) &&
324+ (node -> parent_bit_packed & 3 ) <= 1 &&
325+ !rcu_access_pointer (parent -> peer );
326+ if (free_parent )
327+ child = rcu_dereference_protected (
328+ parent -> bit [!(node -> parent_bit_packed & 1 )],
329+ lockdep_is_held (lock ));
330+ call_rcu (& node -> rcu , node_free_rcu );
331+ if (!free_parent )
332+ continue ;
333+ if (child )
334+ child -> parent_bit_packed = parent -> parent_bit_packed ;
335+ * (struct allowedips_node * * )(parent -> parent_bit_packed & ~3UL ) = child ;
336+ call_rcu (& parent -> rcu , node_free_rcu );
337+ }
341338}
342339
343340int wg_allowedips_read_node (struct allowedips_node * node , u8 ip [16 ], u8 * cidr )
@@ -374,4 +371,16 @@ struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table,
374371 return NULL ;
375372}
376373
374+ int __init wg_allowedips_slab_init (void )
375+ {
376+ node_cache = KMEM_CACHE (allowedips_node , 0 );
377+ return node_cache ? 0 : - ENOMEM ;
378+ }
379+
380+ void wg_allowedips_slab_uninit (void )
381+ {
382+ rcu_barrier ();
383+ kmem_cache_destroy (node_cache );
384+ }
385+
377386#include "selftest/allowedips.c"