|
| 1 | +use std::io::Write; |
| 2 | +use std::mem; |
| 3 | +use std::num::NonZeroUsize; |
| 4 | + |
| 5 | +use rustc_apfloat::FloatConvert; |
| 6 | +use rustc_apfloat::ieee::Double; |
| 7 | +use rustc_apfloat::ieee::Single; |
| 8 | + |
| 9 | +use crate::Args; |
| 10 | +use crate::Commands; |
| 11 | +use crate::FloatRepr; |
| 12 | +use crate::apf_fuzz::FuzzOp; |
| 13 | + |
| 14 | +pub fn run_exhaustive<F: FloatRepr>(cli_args: &Args) -> Result<(), std::num::NonZero<usize>> |
| 15 | +where |
| 16 | + F: Send + 'static, |
| 17 | + Single: FloatConvert<F::RustcApFloat>, |
| 18 | + Double: FloatConvert<F::RustcApFloat>, |
| 19 | +{ |
| 20 | + let Some(Commands::Bruteforce { |
| 21 | + min_width, |
| 22 | + max_width, |
| 23 | + verbose, |
| 24 | + only_non_trivial_fma, |
| 25 | + }) = cli_args.command |
| 26 | + else { |
| 27 | + unreachable!("bruteforce({cli_args:?}): subcommand not `Commands::Bruteforce`"); |
| 28 | + }; |
| 29 | + |
| 30 | + if !(min_width..=max_width).contains(&F::BIT_WIDTH) { |
| 31 | + return Ok(()); |
| 32 | + } |
| 33 | + |
| 34 | + // HACK(eddyb) there is a good chance C++ will also fail, so avoid the |
| 35 | + // (more fatal) C++ assertion failure, via `print_op_and_eval_outputs`. |
| 36 | + let cli_args_plus_ignore_cxx = Args { |
| 37 | + ignore_cxx: true, |
| 38 | + ..cli_args.clone() |
| 39 | + }; |
| 40 | + |
| 41 | + let all_ops = (0..) |
| 42 | + .map(FuzzOp::from_tag) |
| 43 | + .take_while(|op| op.is_some()) |
| 44 | + .map(|op| op.unwrap()) |
| 45 | + .filter(move |op| { |
| 46 | + if only_non_trivial_fma { |
| 47 | + matches!(op, FuzzOp::MulAdd(..)) |
| 48 | + } else { |
| 49 | + true |
| 50 | + } |
| 51 | + }); |
| 52 | + |
| 53 | + let op_to_combined_input_bits_range = move |op: FuzzOp<()>| { |
| 54 | + let mut total_bit_width = 0; |
| 55 | + op.map(|()| total_bit_width += F::BIT_WIDTH); |
| 56 | + |
| 57 | + // HACK(eddyb) the highest `F::BIT_WIDTH` bits are the last input, |
| 58 | + // i.e. the addend for FMA (see also `Commands::Bruteforce` docs). |
| 59 | + let start_combined_input_bits = if only_non_trivial_fma { |
| 60 | + 1 << (total_bit_width - F::BIT_WIDTH) |
| 61 | + } else { |
| 62 | + 0 |
| 63 | + }; |
| 64 | + |
| 65 | + start_combined_input_bits..u128::checked_shl(1, total_bit_width as u32).unwrap() |
| 66 | + }; |
| 67 | + let op_to_exhaustive_cases = move |op: FuzzOp<()>| { |
| 68 | + op_to_combined_input_bits_range(op).map(move |i| -> FuzzOp<F> { |
| 69 | + let mut combined_input_bits = i; |
| 70 | + let op_with_inputs = op.map(|()| { |
| 71 | + let x = combined_input_bits & ((1 << F::BIT_WIDTH) - 1); |
| 72 | + combined_input_bits >>= F::BIT_WIDTH; |
| 73 | + F::from_bits_u128(x) |
| 74 | + }); |
| 75 | + assert_eq!(combined_input_bits, 0); |
| 76 | + op_with_inputs |
| 77 | + }) |
| 78 | + }; |
| 79 | + |
| 80 | + let num_total_cases = all_ops |
| 81 | + .clone() |
| 82 | + .map(|op| { |
| 83 | + let range = op_to_combined_input_bits_range(op); |
| 84 | + range.end.checked_sub(range.start).unwrap() |
| 85 | + }) |
| 86 | + .try_fold(0, u128::checked_add) |
| 87 | + .unwrap(); |
| 88 | + |
| 89 | + let float_name = F::short_lowercase_name(); |
| 90 | + println!("Exhaustively checking {num_total_cases} cases for {float_name}:"); |
| 91 | + |
| 92 | + // HACK(eddyb) show some indication of progress at least every few seconds, |
| 93 | + // but also don't show verbose progress as often, with fewer testcases. |
| 94 | + let num_dots = usize::try_from(num_total_cases >> 23) |
| 95 | + .unwrap_or(usize::MAX) |
| 96 | + .max(if verbose { 10 } else { 40 }); |
| 97 | + let cases_per_dot = |
| 98 | + usize::try_from(num_total_cases / u128::try_from(num_dots).unwrap()).unwrap(); |
| 99 | + |
| 100 | + // Spawn worker threads and only report back from them once in a while |
| 101 | + // (in large batches of successes), or in case of any failure. |
| 102 | + let num_threads = std::thread::available_parallelism().unwrap(); |
| 103 | + let successes_batch_size = (cases_per_dot / num_threads).next_power_of_two(); |
| 104 | + |
| 105 | + struct Update<T> { |
| 106 | + successes: usize, |
| 107 | + mismatch_or_panic: Option<(T, Option<Box<dyn std::any::Any + Send>>)>, |
| 108 | + } |
| 109 | + impl<T> Default for Update<T> { |
| 110 | + fn default() -> Self { |
| 111 | + Update { |
| 112 | + successes: 0, |
| 113 | + mismatch_or_panic: None, |
| 114 | + } |
| 115 | + } |
| 116 | + } |
| 117 | + let (updates_tx, updates_rx) = std::sync::mpsc::channel(); |
| 118 | + |
| 119 | + // HACK(eddyb) avoid reporting panics while iterating. |
| 120 | + std::panic::set_hook(Box::new(|_| {})); |
| 121 | + |
| 122 | + let worker_threads: Vec<_> = (0..num_threads.get()) |
| 123 | + .map(|thread_idx| { |
| 124 | + let cli_args = cli_args.clone(); |
| 125 | + let updates_tx = updates_tx.clone(); |
| 126 | + let cases_per_thread = all_ops |
| 127 | + .clone() |
| 128 | + .flat_map(op_to_exhaustive_cases) |
| 129 | + .skip(thread_idx) |
| 130 | + .step_by(num_threads.get()); |
| 131 | + std::thread::spawn(move || { |
| 132 | + let mut update = Update::default(); |
| 133 | + for op_with_inputs in cases_per_thread { |
| 134 | + // HACK(eddyb) there are still panics we need to account for, |
| 135 | + // e.g. https://github.com/llvm/llvm-project/issues/63895, and |
| 136 | + // even if the Rust code didn't panic, LLVM asserts would trip. |
| 137 | + match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { |
| 138 | + op_with_inputs.eval(&cli_args) |
| 139 | + })) { |
| 140 | + Ok(out) => { |
| 141 | + if out.all_match() { |
| 142 | + update.successes += 1; |
| 143 | + } else { |
| 144 | + update.mismatch_or_panic = Some((op_with_inputs, None)); |
| 145 | + } |
| 146 | + } |
| 147 | + Err(panic) => { |
| 148 | + update.mismatch_or_panic = Some((op_with_inputs, Some(panic))); |
| 149 | + } |
| 150 | + } |
| 151 | + |
| 152 | + if update.successes >= successes_batch_size |
| 153 | + || update.mismatch_or_panic.is_some() |
| 154 | + { |
| 155 | + updates_tx.send(mem::take(&mut update)).unwrap(); |
| 156 | + } |
| 157 | + } |
| 158 | + updates_tx.send(update).unwrap(); |
| 159 | + }) |
| 160 | + }) |
| 161 | + .collect(); |
| 162 | + |
| 163 | + // HACK(eddyb) ensure that `Sender`s are only tied to active threads, |
| 164 | + // allowing the `for` loop below to exit, once all worker threads finish. |
| 165 | + drop(updates_tx); |
| 166 | + |
| 167 | + let mut case_idx = 0; |
| 168 | + let mut current_dot_first_case_idx = 0; |
| 169 | + let mut last_mismatch_case_idx = None; |
| 170 | + let mut last_panic_case_idx = None; |
| 171 | + let mut all_mismatches = vec![]; |
| 172 | + let mut all_panics = vec![]; |
| 173 | + let mut verbose_failed_to_show_some_panics = false; |
| 174 | + for update in updates_rx { |
| 175 | + let Update { |
| 176 | + successes, |
| 177 | + mismatch_or_panic, |
| 178 | + } = update; |
| 179 | + let successes_and_failures = [ |
| 180 | + Some(successes).filter(|&n| n > 0).map(Ok), |
| 181 | + mismatch_or_panic.map(Err), |
| 182 | + ] |
| 183 | + .into_iter() |
| 184 | + .flatten(); |
| 185 | + |
| 186 | + for success_or_failure in successes_and_failures { |
| 187 | + match success_or_failure { |
| 188 | + Ok(successes) => case_idx += successes, |
| 189 | + |
| 190 | + Err((op_with_inputs, None)) => { |
| 191 | + if verbose { |
| 192 | + op_with_inputs.print_op_and_eval_outputs(cli_args); |
| 193 | + } |
| 194 | + |
| 195 | + last_mismatch_case_idx = Some(case_idx); |
| 196 | + all_mismatches.push(op_with_inputs); |
| 197 | + |
| 198 | + case_idx += 1; |
| 199 | + } |
| 200 | + |
| 201 | + Err((op_with_inputs, Some(panic))) => { |
| 202 | + if verbose { |
| 203 | + op_with_inputs.print_op_and_eval_outputs(&cli_args_plus_ignore_cxx); |
| 204 | + if let Ok(msg) = panic.downcast::<String>() { |
| 205 | + eprintln!("panicked with: {msg}"); |
| 206 | + } else { |
| 207 | + verbose_failed_to_show_some_panics = true; |
| 208 | + } |
| 209 | + } |
| 210 | + |
| 211 | + last_panic_case_idx = Some(case_idx); |
| 212 | + all_panics.push(op_with_inputs); |
| 213 | + |
| 214 | + case_idx += 1; |
| 215 | + } |
| 216 | + } |
| 217 | + |
| 218 | + loop { |
| 219 | + let next_dot_first_case_idx = current_dot_first_case_idx + cases_per_dot; |
| 220 | + if case_idx < next_dot_first_case_idx { |
| 221 | + break; |
| 222 | + } |
| 223 | + if verbose { |
| 224 | + println!( |
| 225 | + " {:3.1}% done ({case_idx} / {num_total_cases}), \ |
| 226 | + found {} mismatches and {} panics", |
| 227 | + (case_idx as f64) / (num_total_cases as f64) * 100.0, |
| 228 | + all_mismatches.len(), |
| 229 | + all_panics.len() |
| 230 | + ); |
| 231 | + } else { |
| 232 | + print!( |
| 233 | + "{}", |
| 234 | + if last_panic_case_idx.is_some_and(|i| i >= current_dot_first_case_idx) { |
| 235 | + '🕱' |
| 236 | + } else if last_mismatch_case_idx |
| 237 | + .is_some_and(|i| i >= current_dot_first_case_idx) |
| 238 | + { |
| 239 | + '≠' |
| 240 | + } else { |
| 241 | + '.' |
| 242 | + } |
| 243 | + ); |
| 244 | + // HACK(eddyb) get around `stdout` line buffering. |
| 245 | + std::io::stdout().flush().unwrap(); |
| 246 | + } |
| 247 | + current_dot_first_case_idx = next_dot_first_case_idx; |
| 248 | + } |
| 249 | + } |
| 250 | + } |
| 251 | + println!(); |
| 252 | + |
| 253 | + // HACK(eddyb) undo what we did just before spawning worker threads. |
| 254 | + let _ = std::panic::take_hook(); |
| 255 | + |
| 256 | + for worker_thread in worker_threads { |
| 257 | + worker_thread.join().unwrap(); |
| 258 | + } |
| 259 | + |
| 260 | + // HACK(eddyb) keep only one mismatch per `FuzzOp` variant. |
| 261 | + // FIXME(eddyb) consider sorting these (and panics?) due to parallelism. |
| 262 | + let num_mismatches = all_mismatches.len(); |
| 263 | + let mut select_mismatches = all_mismatches; |
| 264 | + select_mismatches.dedup_by_key(|op_with_inputs| op_with_inputs.tag()); |
| 265 | + |
| 266 | + if num_mismatches > 0 { |
| 267 | + println!(); |
| 268 | + println!( |
| 269 | + "⚠ found {num_mismatches} ({:.1}%) mismatches for {float_name}, showing {} of them:", |
| 270 | + (num_mismatches as f64) / (num_total_cases as f64) * 100.0, |
| 271 | + select_mismatches.len(), |
| 272 | + ); |
| 273 | + for mismatch in select_mismatches { |
| 274 | + mismatch.print_op_and_eval_outputs(cli_args); |
| 275 | + } |
| 276 | + } |
| 277 | + |
| 278 | + if !all_panics.is_empty() { |
| 279 | + println!(); |
| 280 | + println!( |
| 281 | + "⚠ found {} panics for {float_name}, {}", |
| 282 | + all_panics.len(), |
| 283 | + if verbose && !verbose_failed_to_show_some_panics { |
| 284 | + "shown above" |
| 285 | + } else { |
| 286 | + "showing them (without trying C++):" |
| 287 | + }, |
| 288 | + ); |
| 289 | + if !verbose || verbose_failed_to_show_some_panics { |
| 290 | + for &panicking_case in &all_panics { |
| 291 | + panicking_case.print_op_and_eval_outputs(&cli_args_plus_ignore_cxx); |
| 292 | + } |
| 293 | + } |
| 294 | + } |
| 295 | + |
| 296 | + if num_mismatches == 0 && all_panics.is_empty() { |
| 297 | + println!("✔️ all {num_total_cases} cases match"); |
| 298 | + } |
| 299 | + println!(); |
| 300 | + |
| 301 | + NonZeroUsize::new(num_mismatches + all_panics.len()).map_or(Ok(()), Err) |
| 302 | +} |
0 commit comments