Skip to content

Commit 04a42a6

Browse files
nbdd0121BennoLossin
authored andcommitted
rust: macros: convert #[kunit_tests] macro to use syn
Make use of `syn` to parse the module structurally and thus improve the robustness of parsing. String interpolation is avoided by generating tokens directly using `quote!`. Reviewed-by: Tamir Duberstein <tamird@gmail.com> Signed-off-by: Gary Guo <gary@garyguo.net>
1 parent 8ccb62c commit 04a42a6

File tree

2 files changed

+123
-157
lines changed

2 files changed

+123
-157
lines changed

rust/macros/kunit.rs

Lines changed: 119 additions & 155 deletions
Original file line numberDiff line numberDiff line change
@@ -4,81 +4,50 @@
44
//!
55
//! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com>
66
7-
use std::collections::HashMap;
8-
use std::fmt::Write;
9-
10-
use proc_macro2::{Delimiter, Group, TokenStream, TokenTree};
11-
12-
pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
13-
let attr = attr.to_string();
14-
15-
if attr.is_empty() {
16-
panic!("Missing test name in `#[kunit_tests(test_name)]` macro")
17-
}
18-
19-
if attr.len() > 255 {
20-
panic!("The test suite name `{attr}` exceeds the maximum length of 255 bytes")
7+
use std::ffi::CString;
8+
9+
use proc_macro2::TokenStream;
10+
use quote::{
11+
format_ident,
12+
quote,
13+
ToTokens, //
14+
};
15+
use syn::{
16+
parse_quote,
17+
Error,
18+
Ident,
19+
Item,
20+
ItemMod,
21+
LitCStr,
22+
Result, //
23+
};
24+
25+
pub(crate) fn kunit_tests(test_suite: Ident, mut module: ItemMod) -> Result<TokenStream> {
26+
if test_suite.to_string().len() > 255 {
27+
return Err(Error::new_spanned(
28+
test_suite,
29+
"test suite names cannot exceed the maximum length of 255 bytes",
30+
));
2131
}
2232

23-
let mut tokens: Vec<_> = ts.into_iter().collect();
24-
25-
// Scan for the `mod` keyword.
26-
tokens
27-
.iter()
28-
.find_map(|token| match token {
29-
TokenTree::Ident(ident) => match ident.to_string().as_str() {
30-
"mod" => Some(true),
31-
_ => None,
32-
},
33-
_ => None,
34-
})
35-
.expect("`#[kunit_tests(test_name)]` attribute should only be applied to modules");
36-
37-
// Retrieve the main body. The main body should be the last token tree.
38-
let body = match tokens.pop() {
39-
Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group,
40-
_ => panic!("Cannot locate main body of module"),
33+
// We cannot handle modules that defer to another file (e.g. `mod foo;`).
34+
let Some((module_brace, module_items)) = module.content.take() else {
35+
Err(Error::new_spanned(
36+
module,
37+
"`#[kunit_tests(test_name)]` attribute should only be applied to inline modules",
38+
))?
4139
};
4240

43-
// Get the functions set as tests. Search for `[test]` -> `fn`.
44-
let mut body_it = body.stream().into_iter();
45-
let mut tests = Vec::new();
46-
let mut attributes: HashMap<String, TokenStream> = HashMap::new();
47-
while let Some(token) = body_it.next() {
48-
match token {
49-
TokenTree::Punct(ref p) if p.as_char() == '#' => match body_it.next() {
50-
Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => {
51-
if let Some(TokenTree::Ident(name)) = g.stream().into_iter().next() {
52-
// Collect attributes because we need to find which are tests. We also
53-
// need to copy `cfg` attributes so tests can be conditionally enabled.
54-
attributes
55-
.entry(name.to_string())
56-
.or_default()
57-
.extend([token, TokenTree::Group(g)]);
58-
}
59-
continue;
60-
}
61-
_ => (),
62-
},
63-
TokenTree::Ident(i) if i == "fn" && attributes.contains_key("test") => {
64-
if let Some(TokenTree::Ident(test_name)) = body_it.next() {
65-
tests.push((test_name, attributes.remove("cfg").unwrap_or_default()))
66-
}
67-
}
68-
69-
_ => (),
70-
}
71-
attributes.clear();
72-
}
41+
// Make the entire module gated behind `CONFIG_KUNIT`.
42+
module
43+
.attrs
44+
.insert(0, parse_quote!(#[cfg(CONFIG_KUNIT="y")]));
7345

74-
// Add `#[cfg(CONFIG_KUNIT="y")]` before the module declaration.
75-
let config_kunit = "#[cfg(CONFIG_KUNIT=\"y\")]".to_owned().parse().unwrap();
76-
tokens.insert(
77-
0,
78-
TokenTree::Group(Group::new(Delimiter::None, config_kunit)),
79-
);
46+
let mut processed_items = Vec::new();
47+
let mut test_cases = Vec::new();
8048

8149
// Generate the test KUnit test suite and a test case for each `#[test]`.
50+
//
8251
// The code generated for the following test module:
8352
//
8453
// ```
@@ -110,98 +79,93 @@ pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
11079
//
11180
// ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES);
11281
// ```
113-
let mut kunit_macros = "".to_owned();
114-
let mut test_cases = "".to_owned();
115-
let mut assert_macros = "".to_owned();
116-
let path = crate::helpers::file();
117-
let num_tests = tests.len();
118-
for (test, cfg_attr) in tests {
119-
let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{test}");
120-
// Append any `cfg` attributes the user might have written on their tests so we don't
121-
// attempt to call them when they are `cfg`'d out. An extra `use` is used here to reduce
122-
// the length of the assert message.
123-
let kunit_wrapper = format!(
124-
r#"unsafe extern "C" fn {kunit_wrapper_fn_name}(_test: *mut ::kernel::bindings::kunit)
125-
{{
126-
(*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;
127-
{cfg_attr} {{
128-
(*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS;
129-
use ::kernel::kunit::is_test_result_ok;
130-
assert!(is_test_result_ok({test}()));
82+
//
83+
// Non-function items (e.g. imports) are preserved.
84+
for item in module_items {
85+
let Item::Fn(mut f) = item else {
86+
processed_items.push(item);
87+
continue;
88+
};
89+
90+
// TODO: Replace below with `extract_if` when MSRV is bumped above 1.85.
91+
let before_len = f.attrs.len();
92+
f.attrs.retain(|attr| !attr.path().is_ident("test"));
93+
if f.attrs.len() == before_len {
94+
processed_items.push(Item::Fn(f));
95+
continue;
96+
}
97+
98+
let test = f.sig.ident.clone();
99+
100+
// Retrieve `#[cfg]` applied on the function which needs to be present on derived items too.
101+
let cfg_attrs: Vec<_> = f
102+
.attrs
103+
.iter()
104+
.filter(|attr| attr.path().is_ident("cfg"))
105+
.cloned()
106+
.collect();
107+
108+
// Before the test, override usual `assert!` and `assert_eq!` macros with ones that call
109+
// KUnit instead.
110+
let test_str = test.to_string();
111+
let path = crate::helpers::file();
112+
processed_items.push(parse_quote! {
113+
#[allow(unused)]
114+
macro_rules! assert {
115+
($cond:expr $(,)?) => {{
116+
kernel::kunit_assert!(#test_str, #path, 0, $cond);
117+
}}
118+
}
119+
});
120+
processed_items.push(parse_quote! {
121+
#[allow(unused)]
122+
macro_rules! assert_eq {
123+
($left:expr, $right:expr $(,)?) => {{
124+
kernel::kunit_assert_eq!(#test_str, #path, 0, $left, $right);
131125
}}
132-
}}"#,
126+
}
127+
});
128+
129+
// Add back the test item.
130+
processed_items.push(Item::Fn(f));
131+
132+
let kunit_wrapper_fn_name = format_ident!("kunit_rust_wrapper_{test}");
133+
let test_cstr = LitCStr::new(
134+
&CString::new(test_str.as_str()).expect("identifier cannot contain NUL"),
135+
test.span(),
133136
);
134-
writeln!(kunit_macros, "{kunit_wrapper}").unwrap();
135-
writeln!(
136-
test_cases,
137-
" ::kernel::kunit::kunit_case(::kernel::c_str!(\"{test}\"), {kunit_wrapper_fn_name}),"
138-
)
139-
.unwrap();
140-
writeln!(
141-
assert_macros,
142-
r#"
143-
/// Overrides the usual [`assert!`] macro with one that calls KUnit instead.
144-
#[allow(unused)]
145-
macro_rules! assert {{
146-
($cond:expr $(,)?) => {{{{
147-
kernel::kunit_assert!("{test}", "{path}", 0, $cond);
148-
}}}}
149-
}}
150-
151-
/// Overrides the usual [`assert_eq!`] macro with one that calls KUnit instead.
152-
#[allow(unused)]
153-
macro_rules! assert_eq {{
154-
($left:expr, $right:expr $(,)?) => {{{{
155-
kernel::kunit_assert_eq!("{test}", "{path}", 0, $left, $right);
156-
}}}}
157-
}}
158-
"#
159-
)
160-
.unwrap();
161-
}
137+
processed_items.push(parse_quote! {
138+
unsafe extern "C" fn #kunit_wrapper_fn_name(_test: *mut ::kernel::bindings::kunit) {
139+
(*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;
162140

163-
writeln!(kunit_macros).unwrap();
164-
writeln!(
165-
kunit_macros,
166-
"static mut TEST_CASES: [::kernel::bindings::kunit_case; {}] = [\n{test_cases} ::kernel::kunit::kunit_case_null(),\n];",
167-
num_tests + 1
168-
)
169-
.unwrap();
170-
171-
writeln!(
172-
kunit_macros,
173-
"::kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);"
174-
)
175-
.unwrap();
176-
177-
// Remove the `#[test]` macros.
178-
// We do this at a token level, in order to preserve span information.
179-
let mut new_body = vec![];
180-
let mut body_it = body.stream().into_iter();
181-
182-
while let Some(token) = body_it.next() {
183-
match token {
184-
TokenTree::Punct(ref c) if c.as_char() == '#' => match body_it.next() {
185-
Some(TokenTree::Group(group)) if group.to_string() == "[test]" => (),
186-
Some(next) => {
187-
new_body.extend([token, next]);
188-
}
189-
_ => {
190-
new_body.push(token);
141+
// Append any `cfg` attributes the user might have written on their tests so we
142+
// don't attempt to call them when they are `cfg`'d out. An extra `use` is used
143+
// here to reduce the length of the assert message.
144+
#(#cfg_attrs)*
145+
{
146+
(*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS;
147+
use ::kernel::kunit::is_test_result_ok;
148+
assert!(is_test_result_ok(#test()));
191149
}
192-
},
193-
_ => {
194-
new_body.push(token);
195150
}
196-
}
197-
}
151+
});
198152

199-
let mut final_body = TokenStream::new();
200-
final_body.extend::<TokenStream>(assert_macros.parse().unwrap());
201-
final_body.extend(new_body);
202-
final_body.extend::<TokenStream>(kunit_macros.parse().unwrap());
203-
204-
tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, final_body)));
153+
test_cases.push(quote!(
154+
::kernel::kunit::kunit_case(#test_cstr, #kunit_wrapper_fn_name)
155+
));
156+
}
205157

206-
tokens.into_iter().collect()
158+
let num_tests_plus_1 = test_cases.len() + 1;
159+
processed_items.push(parse_quote! {
160+
static mut TEST_CASES: [::kernel::bindings::kunit_case; #num_tests_plus_1] = [
161+
#(#test_cases,)*
162+
::kernel::kunit::kunit_case_null(),
163+
];
164+
});
165+
processed_items.push(parse_quote! {
166+
::kernel::kunit_unsafe_test_suite!(#test_suite, TEST_CASES);
167+
});
168+
169+
module.content = Some((module_brace, processed_items));
170+
Ok(module.to_token_stream())
207171
}

rust/macros/lib.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,8 @@ pub fn paste(input: TokenStream) -> TokenStream {
481481
/// }
482482
/// ```
483483
#[proc_macro_attribute]
484-
pub fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
485-
kunit::kunit_tests(attr.into(), ts.into()).into()
484+
pub fn kunit_tests(attr: TokenStream, input: TokenStream) -> TokenStream {
485+
kunit::kunit_tests(parse_macro_input!(attr), parse_macro_input!(input))
486+
.unwrap_or_else(|e| e.into_compile_error())
487+
.into()
486488
}

0 commit comments

Comments
 (0)