|
4 | 4 | //! |
5 | 5 | //! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com> |
6 | 6 |
|
7 | | -use std::collections::HashMap; |
8 | | -use std::fmt::Write; |
9 | | - |
10 | | -use proc_macro2::{Delimiter, Group, TokenStream, TokenTree}; |
11 | | - |
12 | | -pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream { |
13 | | - let attr = attr.to_string(); |
14 | | - |
15 | | - if attr.is_empty() { |
16 | | - panic!("Missing test name in `#[kunit_tests(test_name)]` macro") |
17 | | - } |
18 | | - |
19 | | - if attr.len() > 255 { |
20 | | - panic!("The test suite name `{attr}` exceeds the maximum length of 255 bytes") |
| 7 | +use std::ffi::CString; |
| 8 | + |
| 9 | +use proc_macro2::TokenStream; |
| 10 | +use quote::{ |
| 11 | + format_ident, |
| 12 | + quote, |
| 13 | + ToTokens, // |
| 14 | +}; |
| 15 | +use syn::{ |
| 16 | + parse_quote, |
| 17 | + Error, |
| 18 | + Ident, |
| 19 | + Item, |
| 20 | + ItemMod, |
| 21 | + LitCStr, |
| 22 | + Result, // |
| 23 | +}; |
| 24 | + |
| 25 | +pub(crate) fn kunit_tests(test_suite: Ident, mut module: ItemMod) -> Result<TokenStream> { |
| 26 | + if test_suite.to_string().len() > 255 { |
| 27 | + return Err(Error::new_spanned( |
| 28 | + test_suite, |
| 29 | + "test suite names cannot exceed the maximum length of 255 bytes", |
| 30 | + )); |
21 | 31 | } |
22 | 32 |
|
23 | | - let mut tokens: Vec<_> = ts.into_iter().collect(); |
24 | | - |
25 | | - // Scan for the `mod` keyword. |
26 | | - tokens |
27 | | - .iter() |
28 | | - .find_map(|token| match token { |
29 | | - TokenTree::Ident(ident) => match ident.to_string().as_str() { |
30 | | - "mod" => Some(true), |
31 | | - _ => None, |
32 | | - }, |
33 | | - _ => None, |
34 | | - }) |
35 | | - .expect("`#[kunit_tests(test_name)]` attribute should only be applied to modules"); |
36 | | - |
37 | | - // Retrieve the main body. The main body should be the last token tree. |
38 | | - let body = match tokens.pop() { |
39 | | - Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group, |
40 | | - _ => panic!("Cannot locate main body of module"), |
| 33 | + // We cannot handle modules that defer to another file (e.g. `mod foo;`). |
| 34 | + let Some((module_brace, module_items)) = module.content.take() else { |
| 35 | + Err(Error::new_spanned( |
| 36 | + module, |
| 37 | + "`#[kunit_tests(test_name)]` attribute should only be applied to inline modules", |
| 38 | + ))? |
41 | 39 | }; |
42 | 40 |
|
43 | | - // Get the functions set as tests. Search for `[test]` -> `fn`. |
44 | | - let mut body_it = body.stream().into_iter(); |
45 | | - let mut tests = Vec::new(); |
46 | | - let mut attributes: HashMap<String, TokenStream> = HashMap::new(); |
47 | | - while let Some(token) = body_it.next() { |
48 | | - match token { |
49 | | - TokenTree::Punct(ref p) if p.as_char() == '#' => match body_it.next() { |
50 | | - Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => { |
51 | | - if let Some(TokenTree::Ident(name)) = g.stream().into_iter().next() { |
52 | | - // Collect attributes because we need to find which are tests. We also |
53 | | - // need to copy `cfg` attributes so tests can be conditionally enabled. |
54 | | - attributes |
55 | | - .entry(name.to_string()) |
56 | | - .or_default() |
57 | | - .extend([token, TokenTree::Group(g)]); |
58 | | - } |
59 | | - continue; |
60 | | - } |
61 | | - _ => (), |
62 | | - }, |
63 | | - TokenTree::Ident(i) if i == "fn" && attributes.contains_key("test") => { |
64 | | - if let Some(TokenTree::Ident(test_name)) = body_it.next() { |
65 | | - tests.push((test_name, attributes.remove("cfg").unwrap_or_default())) |
66 | | - } |
67 | | - } |
68 | | - |
69 | | - _ => (), |
70 | | - } |
71 | | - attributes.clear(); |
72 | | - } |
| 41 | + // Make the entire module gated behind `CONFIG_KUNIT`. |
| 42 | + module |
| 43 | + .attrs |
| 44 | + .insert(0, parse_quote!(#[cfg(CONFIG_KUNIT="y")])); |
73 | 45 |
|
74 | | - // Add `#[cfg(CONFIG_KUNIT="y")]` before the module declaration. |
75 | | - let config_kunit = "#[cfg(CONFIG_KUNIT=\"y\")]".to_owned().parse().unwrap(); |
76 | | - tokens.insert( |
77 | | - 0, |
78 | | - TokenTree::Group(Group::new(Delimiter::None, config_kunit)), |
79 | | - ); |
| 46 | + let mut processed_items = Vec::new(); |
| 47 | + let mut test_cases = Vec::new(); |
80 | 48 |
|
81 | 49 | // Generate the test KUnit test suite and a test case for each `#[test]`. |
| 50 | + // |
82 | 51 | // The code generated for the following test module: |
83 | 52 | // |
84 | 53 | // ``` |
@@ -110,98 +79,93 @@ pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream { |
110 | 79 | // |
111 | 80 | // ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES); |
112 | 81 | // ``` |
113 | | - let mut kunit_macros = "".to_owned(); |
114 | | - let mut test_cases = "".to_owned(); |
115 | | - let mut assert_macros = "".to_owned(); |
116 | | - let path = crate::helpers::file(); |
117 | | - let num_tests = tests.len(); |
118 | | - for (test, cfg_attr) in tests { |
119 | | - let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{test}"); |
120 | | - // Append any `cfg` attributes the user might have written on their tests so we don't |
121 | | - // attempt to call them when they are `cfg`'d out. An extra `use` is used here to reduce |
122 | | - // the length of the assert message. |
123 | | - let kunit_wrapper = format!( |
124 | | - r#"unsafe extern "C" fn {kunit_wrapper_fn_name}(_test: *mut ::kernel::bindings::kunit) |
125 | | - {{ |
126 | | - (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED; |
127 | | - {cfg_attr} {{ |
128 | | - (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS; |
129 | | - use ::kernel::kunit::is_test_result_ok; |
130 | | - assert!(is_test_result_ok({test}())); |
| 82 | + // |
| 83 | + // Non-function items (e.g. imports) are preserved. |
| 84 | + for item in module_items { |
| 85 | + let Item::Fn(mut f) = item else { |
| 86 | + processed_items.push(item); |
| 87 | + continue; |
| 88 | + }; |
| 89 | + |
| 90 | + // TODO: Replace below with `extract_if` when MSRV is bumped above 1.85. |
| 91 | + let before_len = f.attrs.len(); |
| 92 | + f.attrs.retain(|attr| !attr.path().is_ident("test")); |
| 93 | + if f.attrs.len() == before_len { |
| 94 | + processed_items.push(Item::Fn(f)); |
| 95 | + continue; |
| 96 | + } |
| 97 | + |
| 98 | + let test = f.sig.ident.clone(); |
| 99 | + |
| 100 | + // Retrieve `#[cfg]` applied on the function which needs to be present on derived items too. |
| 101 | + let cfg_attrs: Vec<_> = f |
| 102 | + .attrs |
| 103 | + .iter() |
| 104 | + .filter(|attr| attr.path().is_ident("cfg")) |
| 105 | + .cloned() |
| 106 | + .collect(); |
| 107 | + |
| 108 | + // Before the test, override usual `assert!` and `assert_eq!` macros with ones that call |
| 109 | + // KUnit instead. |
| 110 | + let test_str = test.to_string(); |
| 111 | + let path = crate::helpers::file(); |
| 112 | + processed_items.push(parse_quote! { |
| 113 | + #[allow(unused)] |
| 114 | + macro_rules! assert { |
| 115 | + ($cond:expr $(,)?) => {{ |
| 116 | + kernel::kunit_assert!(#test_str, #path, 0, $cond); |
| 117 | + }} |
| 118 | + } |
| 119 | + }); |
| 120 | + processed_items.push(parse_quote! { |
| 121 | + #[allow(unused)] |
| 122 | + macro_rules! assert_eq { |
| 123 | + ($left:expr, $right:expr $(,)?) => {{ |
| 124 | + kernel::kunit_assert_eq!(#test_str, #path, 0, $left, $right); |
131 | 125 | }} |
132 | | - }}"#, |
| 126 | + } |
| 127 | + }); |
| 128 | + |
| 129 | + // Add back the test item. |
| 130 | + processed_items.push(Item::Fn(f)); |
| 131 | + |
| 132 | + let kunit_wrapper_fn_name = format_ident!("kunit_rust_wrapper_{test}"); |
| 133 | + let test_cstr = LitCStr::new( |
| 134 | + &CString::new(test_str.as_str()).expect("identifier cannot contain NUL"), |
| 135 | + test.span(), |
133 | 136 | ); |
134 | | - writeln!(kunit_macros, "{kunit_wrapper}").unwrap(); |
135 | | - writeln!( |
136 | | - test_cases, |
137 | | - " ::kernel::kunit::kunit_case(::kernel::c_str!(\"{test}\"), {kunit_wrapper_fn_name})," |
138 | | - ) |
139 | | - .unwrap(); |
140 | | - writeln!( |
141 | | - assert_macros, |
142 | | - r#" |
143 | | -/// Overrides the usual [`assert!`] macro with one that calls KUnit instead. |
144 | | -#[allow(unused)] |
145 | | -macro_rules! assert {{ |
146 | | - ($cond:expr $(,)?) => {{{{ |
147 | | - kernel::kunit_assert!("{test}", "{path}", 0, $cond); |
148 | | - }}}} |
149 | | -}} |
150 | | -
|
151 | | -/// Overrides the usual [`assert_eq!`] macro with one that calls KUnit instead. |
152 | | -#[allow(unused)] |
153 | | -macro_rules! assert_eq {{ |
154 | | - ($left:expr, $right:expr $(,)?) => {{{{ |
155 | | - kernel::kunit_assert_eq!("{test}", "{path}", 0, $left, $right); |
156 | | - }}}} |
157 | | -}} |
158 | | - "# |
159 | | - ) |
160 | | - .unwrap(); |
161 | | - } |
| 137 | + processed_items.push(parse_quote! { |
| 138 | + unsafe extern "C" fn #kunit_wrapper_fn_name(_test: *mut ::kernel::bindings::kunit) { |
| 139 | + (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED; |
162 | 140 |
|
163 | | - writeln!(kunit_macros).unwrap(); |
164 | | - writeln!( |
165 | | - kunit_macros, |
166 | | - "static mut TEST_CASES: [::kernel::bindings::kunit_case; {}] = [\n{test_cases} ::kernel::kunit::kunit_case_null(),\n];", |
167 | | - num_tests + 1 |
168 | | - ) |
169 | | - .unwrap(); |
170 | | - |
171 | | - writeln!( |
172 | | - kunit_macros, |
173 | | - "::kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);" |
174 | | - ) |
175 | | - .unwrap(); |
176 | | - |
177 | | - // Remove the `#[test]` macros. |
178 | | - // We do this at a token level, in order to preserve span information. |
179 | | - let mut new_body = vec![]; |
180 | | - let mut body_it = body.stream().into_iter(); |
181 | | - |
182 | | - while let Some(token) = body_it.next() { |
183 | | - match token { |
184 | | - TokenTree::Punct(ref c) if c.as_char() == '#' => match body_it.next() { |
185 | | - Some(TokenTree::Group(group)) if group.to_string() == "[test]" => (), |
186 | | - Some(next) => { |
187 | | - new_body.extend([token, next]); |
188 | | - } |
189 | | - _ => { |
190 | | - new_body.push(token); |
| 141 | + // Append any `cfg` attributes the user might have written on their tests so we |
| 142 | + // don't attempt to call them when they are `cfg`'d out. An extra `use` is used |
| 143 | + // here to reduce the length of the assert message. |
| 144 | + #(#cfg_attrs)* |
| 145 | + { |
| 146 | + (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS; |
| 147 | + use ::kernel::kunit::is_test_result_ok; |
| 148 | + assert!(is_test_result_ok(#test())); |
191 | 149 | } |
192 | | - }, |
193 | | - _ => { |
194 | | - new_body.push(token); |
195 | 150 | } |
196 | | - } |
197 | | - } |
| 151 | + }); |
198 | 152 |
|
199 | | - let mut final_body = TokenStream::new(); |
200 | | - final_body.extend::<TokenStream>(assert_macros.parse().unwrap()); |
201 | | - final_body.extend(new_body); |
202 | | - final_body.extend::<TokenStream>(kunit_macros.parse().unwrap()); |
203 | | - |
204 | | - tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, final_body))); |
| 153 | + test_cases.push(quote!( |
| 154 | + ::kernel::kunit::kunit_case(#test_cstr, #kunit_wrapper_fn_name) |
| 155 | + )); |
| 156 | + } |
205 | 157 |
|
206 | | - tokens.into_iter().collect() |
| 158 | + let num_tests_plus_1 = test_cases.len() + 1; |
| 159 | + processed_items.push(parse_quote! { |
| 160 | + static mut TEST_CASES: [::kernel::bindings::kunit_case; #num_tests_plus_1] = [ |
| 161 | + #(#test_cases,)* |
| 162 | + ::kernel::kunit::kunit_case_null(), |
| 163 | + ]; |
| 164 | + }); |
| 165 | + processed_items.push(parse_quote! { |
| 166 | + ::kernel::kunit_unsafe_test_suite!(#test_suite, TEST_CASES); |
| 167 | + }); |
| 168 | + |
| 169 | + module.content = Some((module_brace, processed_items)); |
| 170 | + Ok(module.to_token_stream()) |
207 | 171 | } |
0 commit comments