@@ -8,11 +8,36 @@ use windows::Win32::Foundation::{S_FALSE, VARIANT_BOOL};
88use windows:: Win32 :: NetworkManagement :: WindowsFirewall :: * ;
99use windows:: Win32 :: System :: Com :: { CLSCTX_INPROC_SERVER , CoCreateInstance , CoInitializeEx , CoUninitialize , IDispatch , COINIT_APARTMENTTHREADED } ;
1010use windows:: Win32 :: System :: Ole :: IEnumVARIANT ;
11- use windows:: Win32 :: System :: Variant :: VARIANT ;
11+ use windows:: Win32 :: System :: Variant :: { VARIANT , VariantClear } ;
1212
1313use crate :: types:: { FirewallError , FirewallRule , FirewallRuleList , RuleAction , RuleDirection } ;
1414use crate :: util:: matches_any_filter;
1515
16+ /// RAII wrapper for VARIANT that automatically calls VariantClear on drop
17+ struct SafeVariant ( VARIANT ) ;
18+
19+ impl SafeVariant {
20+ fn new ( ) -> Self {
21+ Self ( VARIANT :: default ( ) )
22+ }
23+
24+ fn as_mut_ptr ( & mut self ) -> * mut VARIANT {
25+ & mut self . 0
26+ }
27+
28+ fn as_ref ( & self ) -> & VARIANT {
29+ & self . 0
30+ }
31+ }
32+
33+ impl Drop for SafeVariant {
34+ fn drop ( & mut self ) {
35+ if let Err ( e) = unsafe { VariantClear ( & mut self . 0 ) } {
36+ crate :: write_error ( & format ! ( "Warning: VariantClear failed with HRESULT: {:#010x}" , e. code( ) . 0 ) ) ;
37+ }
38+ }
39+ }
40+
1641struct ComGuard ;
1742
1843impl ComGuard {
@@ -55,20 +80,23 @@ impl FirewallStore {
5580 let mut results = Vec :: new ( ) ;
5681 loop {
5782 let mut fetched = 0u32 ;
58- let mut variant = [ VARIANT :: default ( ) ] ;
59- let hr = unsafe { enum_variant. Next ( & mut variant, & mut fetched) } ;
83+ let mut safe_variant = SafeVariant :: new ( ) ;
84+ let variant_slice = unsafe { std:: slice:: from_raw_parts_mut ( safe_variant. as_mut_ptr ( ) , 1 ) } ;
85+ let hr = unsafe { enum_variant. Next ( variant_slice, & mut fetched) } ;
6086 if hr == S_FALSE || fetched == 0 {
6187 break ;
6288 }
6389 hr. ok ( )
6490 . map_err ( |error| t ! ( "firewall.ruleEnumerationFailed" , error = error. to_string( ) ) . to_string ( ) ) ?;
6591
66- let dispatch = IDispatch :: try_from ( & variant [ 0 ] )
92+ let dispatch = IDispatch :: try_from ( safe_variant . as_ref ( ) )
6793 . map_err ( |error : windows:: core:: Error | t ! ( "firewall.ruleEnumerationFailed" , error = error. to_string( ) ) . to_string ( ) ) ?;
6894 let rule: INetFwRule = dispatch
6995 . cast ( )
7096 . map_err ( |error| t ! ( "firewall.ruleEnumerationFailed" , error = error. to_string( ) ) . to_string ( ) ) ?;
7197 results. push ( rule) ;
98+
99+ // SafeVariant will automatically call VariantClear when it goes out of scope
72100 }
73101
74102 Ok ( results)
@@ -169,6 +197,10 @@ fn profiles_from_mask(mask: i32) -> Vec<String> {
169197}
170198
171199fn profiles_to_mask ( values : & [ String ] ) -> Result < i32 , FirewallError > {
200+ if values. is_empty ( ) {
201+ return Ok ( NET_FW_PROFILE2_ALL . 0 ) ;
202+ }
203+
172204 let mut mask = 0 ;
173205 for value in values {
174206 match value. to_ascii_lowercase ( ) . as_str ( ) {
@@ -197,6 +229,10 @@ fn join_csv(value: &[String]) -> String {
197229}
198230
199231fn interface_types_to_string ( values : & [ String ] ) -> Result < String , FirewallError > {
232+ if values. is_empty ( ) {
233+ return Ok ( "All" . to_string ( ) ) ;
234+ }
235+
200236 let mut normalized = Vec :: new ( ) ;
201237 for value in values {
202238 match value. to_ascii_lowercase ( ) . as_str ( ) {
@@ -269,6 +305,12 @@ fn apply_rule_properties(rule: &INetFwRule, desired: &FirewallRule, existing_pro
269305 // the existing rule's protocol (if updating an existing rule).
270306 let effective_protocol = desired. protocol . or ( existing_protocol) ;
271307
308+ // If effective_protocol is None, read the current protocol from the rule.
309+ let effective_protocol = match effective_protocol {
310+ Some ( protocol) => Some ( protocol) ,
311+ None => Some ( unsafe { rule. Protocol ( ) } . map_err ( & err) ?) ,
312+ } ;
313+
272314 // Reject port specifications for protocols that don't support them (e.g. ICMP).
273315 // This must be checked regardless of whether the protocol itself was changed,
274316 // because the caller may only be setting local_ports or remote_ports.
0 commit comments