@@ -25,6 +25,7 @@ class Netlink:
2525 NETLINK_ADD_MEMBERSHIP = 1
2626 NETLINK_CAP_ACK = 10
2727 NETLINK_EXT_ACK = 11
28+ NETLINK_GET_STRICT_CHK = 12
2829
2930 # Netlink message
3031 NLMSG_ERROR = 2
@@ -228,6 +229,9 @@ def __init__(self, msg, offset, attr_space=None):
228229 desc += f" ({ spec ['doc' ]} )"
229230 self .extack ['miss-type' ] = desc
230231
232+ def cmd (self ):
233+ return self .nl_type
234+
231235 def __repr__ (self ):
232236 msg = f"nl_len = { self .nl_len } ({ len (self .raw )} ) nl_flags = 0x{ self .nl_flags :x} nl_type = { self .nl_type } \n "
233237 if self .error :
@@ -322,6 +326,9 @@ def __init__(self, nl_msg):
322326 self .genl_cmd , self .genl_version , _ = struct .unpack_from ("BBH" , nl_msg .raw , 0 )
323327 self .raw = nl_msg .raw [4 :]
324328
329+ def cmd (self ):
330+ return self .genl_cmd
331+
325332 def __repr__ (self ):
326333 msg = repr (self .nl )
327334 msg += f"\t genl_cmd = { self .genl_cmd } genl_ver = { self .genl_version } \n "
@@ -330,9 +337,41 @@ def __repr__(self):
330337 return msg
331338
332339
333- class GenlFamily :
334- def __init__ (self , family_name ):
340+ class NetlinkProtocol :
341+ def __init__ (self , family_name , proto_num ):
335342 self .family_name = family_name
343+ self .proto_num = proto_num
344+
345+ def _message (self , nl_type , nl_flags , seq = None ):
346+ if seq is None :
347+ seq = random .randint (1 , 1024 )
348+ nlmsg = struct .pack ("HHII" , nl_type , nl_flags , seq , 0 )
349+ return nlmsg
350+
351+ def message (self , flags , command , version , seq = None ):
352+ return self ._message (command , flags , seq )
353+
354+ def _decode (self , nl_msg ):
355+ return nl_msg
356+
357+ def decode (self , ynl , nl_msg ):
358+ msg = self ._decode (nl_msg )
359+ fixed_header_size = 0
360+ if ynl :
361+ op = ynl .rsp_by_value [msg .cmd ()]
362+ fixed_header_size = ynl ._fixed_header_size (op )
363+ msg .raw_attrs = NlAttrs (msg .raw [fixed_header_size :])
364+ return msg
365+
366+ def get_mcast_id (self , mcast_name , mcast_groups ):
367+ if mcast_name not in mcast_groups :
368+ raise Exception (f'Multicast group "{ mcast_name } " not present in the spec' )
369+ return mcast_groups [mcast_name ].value
370+
371+
372+ class GenlProtocol (NetlinkProtocol ):
373+ def __init__ (self , family_name ):
374+ super ().__init__ (family_name , Netlink .NETLINK_GENERIC )
336375
337376 global genl_family_name_to_id
338377 if genl_family_name_to_id is None :
@@ -341,6 +380,19 @@ def __init__(self, family_name):
341380 self .genl_family = genl_family_name_to_id [family_name ]
342381 self .family_id = genl_family_name_to_id [family_name ]['id' ]
343382
383+ def message (self , flags , command , version , seq = None ):
384+ nlmsg = self ._message (self .family_id , flags , seq )
385+ genlmsg = struct .pack ("BBH" , command , version , 0 )
386+ return nlmsg + genlmsg
387+
388+ def _decode (self , nl_msg ):
389+ return GenlMsg (nl_msg )
390+
391+ def get_mcast_id (self , mcast_name , mcast_groups ):
392+ if mcast_name not in self .genl_family ['mcast' ]:
393+ raise Exception (f'Multicast group "{ mcast_name } " not present in the family' )
394+ return self .genl_family ['mcast' ][mcast_name ]
395+
344396
345397#
346398# YNL implementation details.
@@ -353,9 +405,19 @@ def __init__(self, def_path, schema=None):
353405
354406 self .include_raw = False
355407
356- self .sock = socket .socket (socket .AF_NETLINK , socket .SOCK_RAW , Netlink .NETLINK_GENERIC )
408+ try :
409+ if self .proto == "netlink-raw" :
410+ self .nlproto = NetlinkProtocol (self .yaml ['name' ],
411+ self .yaml ['protonum' ])
412+ else :
413+ self .nlproto = GenlProtocol (self .yaml ['name' ])
414+ except KeyError :
415+ raise Exception (f"Family '{ self .yaml ['name' ]} ' not supported by the kernel" )
416+
417+ self .sock = socket .socket (socket .AF_NETLINK , socket .SOCK_RAW , self .nlproto .proto_num )
357418 self .sock .setsockopt (Netlink .SOL_NETLINK , Netlink .NETLINK_CAP_ACK , 1 )
358419 self .sock .setsockopt (Netlink .SOL_NETLINK , Netlink .NETLINK_EXT_ACK , 1 )
420+ self .sock .setsockopt (Netlink .SOL_NETLINK , Netlink .NETLINK_GET_STRICT_CHK , 1 )
359421
360422 self .async_msg_ids = set ()
361423 self .async_msg_queue = []
@@ -368,18 +430,12 @@ def __init__(self, def_path, schema=None):
368430 bound_f = functools .partial (self ._op , op_name )
369431 setattr (self , op .ident_name , bound_f )
370432
371- try :
372- self .family = GenlFamily (self .yaml ['name' ])
373- except KeyError :
374- raise Exception (f"Family '{ self .yaml ['name' ]} ' not supported by the kernel" )
375433
376434 def ntf_subscribe (self , mcast_name ):
377- if mcast_name not in self .family .genl_family ['mcast' ]:
378- raise Exception (f'Multicast group "{ mcast_name } " not present in the family' )
379-
435+ mcast_id = self .nlproto .get_mcast_id (mcast_name , self .mcast_groups )
380436 self .sock .bind ((0 , 0 ))
381437 self .sock .setsockopt (Netlink .SOL_NETLINK , Netlink .NETLINK_ADD_MEMBERSHIP ,
382- self . family . genl_family [ 'mcast' ][ mcast_name ] )
438+ mcast_id )
383439
384440 def _add_attr (self , space , name , value ):
385441 try :
@@ -505,11 +561,9 @@ def _decode_extack(self, request, op, extack):
505561 if 'bad-attr-offs' not in extack :
506562 return
507563
508- genl_req = GenlMsg (NlMsg (request , 0 , op .attr_set ))
509- fixed_header_size = self ._fixed_header_size (op )
510- offset = 20 + fixed_header_size
511- path = self ._decode_extack_path (NlAttrs (genl_req .raw [fixed_header_size :]),
512- op .attr_set , offset ,
564+ msg = self .nlproto .decode (self , NlMsg (request , 0 , op .attr_set ))
565+ offset = 20 + self ._fixed_header_size (op )
566+ path = self ._decode_extack_path (msg .raw_attrs , op .attr_set , offset ,
513567 extack ['bad-attr-offs' ])
514568 if path :
515569 del extack ['bad-attr-offs' ]
@@ -539,14 +593,17 @@ def _decode_fixed_header(self, msg, name):
539593 fixed_header_attrs [m .name ] = value
540594 return fixed_header_attrs
541595
542- def handle_ntf (self , nl_msg , genl_msg ):
596+ def handle_ntf (self , decoded ):
543597 msg = dict ()
544598 if self .include_raw :
545- msg ['nlmsg' ] = nl_msg
546- msg ['genlmsg' ] = genl_msg
547- op = self .rsp_by_value [genl_msg .genl_cmd ]
599+ msg ['raw' ] = decoded
600+ op = self .rsp_by_value [decoded .cmd ()]
601+ attrs = self ._decode (decoded .raw_attrs , op .attr_set .name )
602+ if op .fixed_header :
603+ attrs .update (self ._decode_fixed_header (decoded , op .fixed_header ))
604+
548605 msg ['name' ] = op ['name' ]
549- msg ['msg' ] = self . _decode ( genl_msg . raw_attrs , op . attr_set . name )
606+ msg ['msg' ] = attrs
550607 self .async_msg_queue .append (msg )
551608
552609 def check_ntf (self ):
@@ -566,12 +623,12 @@ def check_ntf(self):
566623 print ("Netlink done while checking for ntf!?" )
567624 continue
568625
569- gm = GenlMsg ( nl_msg )
570- if gm . genl_cmd not in self .async_msg_ids :
571- print ("Unexpected msg id done while checking for ntf" , gm )
626+ decoded = self . nlproto . decode ( self , nl_msg )
627+ if decoded . cmd () not in self .async_msg_ids :
628+ print ("Unexpected msg id done while checking for ntf" , decoded )
572629 continue
573630
574- self .handle_ntf (nl_msg , gm )
631+ self .handle_ntf (decoded )
575632
576633 def operation_do_attributes (self , name ):
577634 """
@@ -592,7 +649,7 @@ def _op(self, method, vals, dump=False):
592649 nl_flags |= Netlink .NLM_F_DUMP
593650
594651 req_seq = random .randint (1024 , 65535 )
595- msg = _genl_msg ( self .family . family_id , nl_flags , op .req_value , 1 , req_seq )
652+ msg = self .nlproto . message ( nl_flags , op .req_value , 1 , req_seq )
596653 fixed_header_members = []
597654 if op .fixed_header :
598655 fixed_header_members = self .consts [op .fixed_header ].members
@@ -624,19 +681,20 @@ def _op(self, method, vals, dump=False):
624681 done = True
625682 break
626683
627- gm = GenlMsg (nl_msg )
684+ decoded = self .nlproto .decode (self , nl_msg )
685+
628686 # Check if this is a reply to our request
629- if nl_msg .nl_seq != req_seq or gm . genl_cmd != op .rsp_value :
630- if gm . genl_cmd in self .async_msg_ids :
631- self .handle_ntf (nl_msg , gm )
687+ if nl_msg .nl_seq != req_seq or decoded . cmd () != op .rsp_value :
688+ if decoded . cmd () in self .async_msg_ids :
689+ self .handle_ntf (decoded )
632690 continue
633691 else :
634- print ('Unexpected message: ' + repr (gm ))
692+ print ('Unexpected message: ' + repr (decoded ))
635693 continue
636694
637- rsp_msg = self ._decode (NlAttrs ( gm . raw ) , op .attr_set .name )
695+ rsp_msg = self ._decode (decoded . raw_attrs , op .attr_set .name )
638696 if op .fixed_header :
639- rsp_msg .update (self ._decode_fixed_header (gm , op .fixed_header ))
697+ rsp_msg .update (self ._decode_fixed_header (decoded , op .fixed_header ))
640698 rsp .append (rsp_msg )
641699
642700 if not rsp :
0 commit comments