Skip to content

Commit debb27c

Browse files
committed
fuzz: Extract brute force to a separate module
1 parent 2e2b1c1 commit debb27c

2 files changed

Lines changed: 305 additions & 285 deletions

File tree

fuzz/src/exhaustive.rs

Lines changed: 302 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
use std::io::Write;
2+
use std::mem;
3+
use std::num::NonZeroUsize;
4+
5+
use rustc_apfloat::FloatConvert;
6+
use rustc_apfloat::ieee::Double;
7+
use rustc_apfloat::ieee::Single;
8+
9+
use crate::Args;
10+
use crate::Commands;
11+
use crate::FloatRepr;
12+
use crate::apf_fuzz::FuzzOp;
13+
14+
pub fn run_exhaustive<F: FloatRepr>(cli_args: &Args) -> Result<(), std::num::NonZero<usize>>
15+
where
16+
F: Send + 'static,
17+
Single: FloatConvert<F::RustcApFloat>,
18+
Double: FloatConvert<F::RustcApFloat>,
19+
{
20+
let Some(Commands::Bruteforce {
21+
min_width,
22+
max_width,
23+
verbose,
24+
only_non_trivial_fma,
25+
}) = cli_args.command
26+
else {
27+
unreachable!("bruteforce({cli_args:?}): subcommand not `Commands::Bruteforce`");
28+
};
29+
30+
if !(min_width..=max_width).contains(&F::BIT_WIDTH) {
31+
return Ok(());
32+
}
33+
34+
// HACK(eddyb) there is a good chance C++ will also fail, so avoid the
35+
// (more fatal) C++ assertion failure, via `print_op_and_eval_outputs`.
36+
let cli_args_plus_ignore_cxx = Args {
37+
ignore_cxx: true,
38+
..cli_args.clone()
39+
};
40+
41+
let all_ops = (0..)
42+
.map(FuzzOp::from_tag)
43+
.take_while(|op| op.is_some())
44+
.map(|op| op.unwrap())
45+
.filter(move |op| {
46+
if only_non_trivial_fma {
47+
matches!(op, FuzzOp::MulAdd(..))
48+
} else {
49+
true
50+
}
51+
});
52+
53+
let op_to_combined_input_bits_range = move |op: FuzzOp<()>| {
54+
let mut total_bit_width = 0;
55+
op.map(|()| total_bit_width += F::BIT_WIDTH);
56+
57+
// HACK(eddyb) the highest `F::BIT_WIDTH` bits are the last input,
58+
// i.e. the addend for FMA (see also `Commands::Bruteforce` docs).
59+
let start_combined_input_bits = if only_non_trivial_fma {
60+
1 << (total_bit_width - F::BIT_WIDTH)
61+
} else {
62+
0
63+
};
64+
65+
start_combined_input_bits..u128::checked_shl(1, total_bit_width as u32).unwrap()
66+
};
67+
let op_to_exhaustive_cases = move |op: FuzzOp<()>| {
68+
op_to_combined_input_bits_range(op).map(move |i| -> FuzzOp<F> {
69+
let mut combined_input_bits = i;
70+
let op_with_inputs = op.map(|()| {
71+
let x = combined_input_bits & ((1 << F::BIT_WIDTH) - 1);
72+
combined_input_bits >>= F::BIT_WIDTH;
73+
F::from_bits_u128(x)
74+
});
75+
assert_eq!(combined_input_bits, 0);
76+
op_with_inputs
77+
})
78+
};
79+
80+
let num_total_cases = all_ops
81+
.clone()
82+
.map(|op| {
83+
let range = op_to_combined_input_bits_range(op);
84+
range.end.checked_sub(range.start).unwrap()
85+
})
86+
.try_fold(0, u128::checked_add)
87+
.unwrap();
88+
89+
let float_name = F::short_lowercase_name();
90+
println!("Exhaustively checking {num_total_cases} cases for {float_name}:");
91+
92+
// HACK(eddyb) show some indication of progress at least every few seconds,
93+
// but also don't show verbose progress as often, with fewer testcases.
94+
let num_dots = usize::try_from(num_total_cases >> 23)
95+
.unwrap_or(usize::MAX)
96+
.max(if verbose { 10 } else { 40 });
97+
let cases_per_dot =
98+
usize::try_from(num_total_cases / u128::try_from(num_dots).unwrap()).unwrap();
99+
100+
// Spawn worker threads and only report back from them once in a while
101+
// (in large batches of successes), or in case of any failure.
102+
let num_threads = std::thread::available_parallelism().unwrap();
103+
let successes_batch_size = (cases_per_dot / num_threads).next_power_of_two();
104+
105+
struct Update<T> {
106+
successes: usize,
107+
mismatch_or_panic: Option<(T, Option<Box<dyn std::any::Any + Send>>)>,
108+
}
109+
impl<T> Default for Update<T> {
110+
fn default() -> Self {
111+
Update {
112+
successes: 0,
113+
mismatch_or_panic: None,
114+
}
115+
}
116+
}
117+
let (updates_tx, updates_rx) = std::sync::mpsc::channel();
118+
119+
// HACK(eddyb) avoid reporting panics while iterating.
120+
std::panic::set_hook(Box::new(|_| {}));
121+
122+
let worker_threads: Vec<_> = (0..num_threads.get())
123+
.map(|thread_idx| {
124+
let cli_args = cli_args.clone();
125+
let updates_tx = updates_tx.clone();
126+
let cases_per_thread = all_ops
127+
.clone()
128+
.flat_map(op_to_exhaustive_cases)
129+
.skip(thread_idx)
130+
.step_by(num_threads.get());
131+
std::thread::spawn(move || {
132+
let mut update = Update::default();
133+
for op_with_inputs in cases_per_thread {
134+
// HACK(eddyb) there are still panics we need to account for,
135+
// e.g. https://github.com/llvm/llvm-project/issues/63895, and
136+
// even if the Rust code didn't panic, LLVM asserts would trip.
137+
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
138+
op_with_inputs.eval(&cli_args)
139+
})) {
140+
Ok(out) => {
141+
if out.all_match() {
142+
update.successes += 1;
143+
} else {
144+
update.mismatch_or_panic = Some((op_with_inputs, None));
145+
}
146+
}
147+
Err(panic) => {
148+
update.mismatch_or_panic = Some((op_with_inputs, Some(panic)));
149+
}
150+
}
151+
152+
if update.successes >= successes_batch_size
153+
|| update.mismatch_or_panic.is_some()
154+
{
155+
updates_tx.send(mem::take(&mut update)).unwrap();
156+
}
157+
}
158+
updates_tx.send(update).unwrap();
159+
})
160+
})
161+
.collect();
162+
163+
// HACK(eddyb) ensure that `Sender`s are only tied to active threads,
164+
// allowing the `for` loop below to exit, once all worker threads finish.
165+
drop(updates_tx);
166+
167+
let mut case_idx = 0;
168+
let mut current_dot_first_case_idx = 0;
169+
let mut last_mismatch_case_idx = None;
170+
let mut last_panic_case_idx = None;
171+
let mut all_mismatches = vec![];
172+
let mut all_panics = vec![];
173+
let mut verbose_failed_to_show_some_panics = false;
174+
for update in updates_rx {
175+
let Update {
176+
successes,
177+
mismatch_or_panic,
178+
} = update;
179+
let successes_and_failures = [
180+
Some(successes).filter(|&n| n > 0).map(Ok),
181+
mismatch_or_panic.map(Err),
182+
]
183+
.into_iter()
184+
.flatten();
185+
186+
for success_or_failure in successes_and_failures {
187+
match success_or_failure {
188+
Ok(successes) => case_idx += successes,
189+
190+
Err((op_with_inputs, None)) => {
191+
if verbose {
192+
op_with_inputs.print_op_and_eval_outputs(cli_args);
193+
}
194+
195+
last_mismatch_case_idx = Some(case_idx);
196+
all_mismatches.push(op_with_inputs);
197+
198+
case_idx += 1;
199+
}
200+
201+
Err((op_with_inputs, Some(panic))) => {
202+
if verbose {
203+
op_with_inputs.print_op_and_eval_outputs(&cli_args_plus_ignore_cxx);
204+
if let Ok(msg) = panic.downcast::<String>() {
205+
eprintln!("panicked with: {msg}");
206+
} else {
207+
verbose_failed_to_show_some_panics = true;
208+
}
209+
}
210+
211+
last_panic_case_idx = Some(case_idx);
212+
all_panics.push(op_with_inputs);
213+
214+
case_idx += 1;
215+
}
216+
}
217+
218+
loop {
219+
let next_dot_first_case_idx = current_dot_first_case_idx + cases_per_dot;
220+
if case_idx < next_dot_first_case_idx {
221+
break;
222+
}
223+
if verbose {
224+
println!(
225+
" {:3.1}% done ({case_idx} / {num_total_cases}), \
226+
found {} mismatches and {} panics",
227+
(case_idx as f64) / (num_total_cases as f64) * 100.0,
228+
all_mismatches.len(),
229+
all_panics.len()
230+
);
231+
} else {
232+
print!(
233+
"{}",
234+
if last_panic_case_idx.is_some_and(|i| i >= current_dot_first_case_idx) {
235+
'🕱'
236+
} else if last_mismatch_case_idx
237+
.is_some_and(|i| i >= current_dot_first_case_idx)
238+
{
239+
'≠'
240+
} else {
241+
'.'
242+
}
243+
);
244+
// HACK(eddyb) get around `stdout` line buffering.
245+
std::io::stdout().flush().unwrap();
246+
}
247+
current_dot_first_case_idx = next_dot_first_case_idx;
248+
}
249+
}
250+
}
251+
println!();
252+
253+
// HACK(eddyb) undo what we did just before spawning worker threads.
254+
let _ = std::panic::take_hook();
255+
256+
for worker_thread in worker_threads {
257+
worker_thread.join().unwrap();
258+
}
259+
260+
// HACK(eddyb) keep only one mismatch per `FuzzOp` variant.
261+
// FIXME(eddyb) consider sorting these (and panics?) due to parallelism.
262+
let num_mismatches = all_mismatches.len();
263+
let mut select_mismatches = all_mismatches;
264+
select_mismatches.dedup_by_key(|op_with_inputs| op_with_inputs.tag());
265+
266+
if num_mismatches > 0 {
267+
println!();
268+
println!(
269+
"⚠ found {num_mismatches} ({:.1}%) mismatches for {float_name}, showing {} of them:",
270+
(num_mismatches as f64) / (num_total_cases as f64) * 100.0,
271+
select_mismatches.len(),
272+
);
273+
for mismatch in select_mismatches {
274+
mismatch.print_op_and_eval_outputs(cli_args);
275+
}
276+
}
277+
278+
if !all_panics.is_empty() {
279+
println!();
280+
println!(
281+
"⚠ found {} panics for {float_name}, {}",
282+
all_panics.len(),
283+
if verbose && !verbose_failed_to_show_some_panics {
284+
"shown above"
285+
} else {
286+
"showing them (without trying C++):"
287+
},
288+
);
289+
if !verbose || verbose_failed_to_show_some_panics {
290+
for &panicking_case in &all_panics {
291+
panicking_case.print_op_and_eval_outputs(&cli_args_plus_ignore_cxx);
292+
}
293+
}
294+
}
295+
296+
if num_mismatches == 0 && all_panics.is_empty() {
297+
println!("✔️ all {num_total_cases} cases match");
298+
}
299+
println!();
300+
301+
NonZeroUsize::new(num_mismatches + all_panics.len()).map_or(Ok(()), Err)
302+
}

0 commit comments

Comments
 (0)