11use std:: io:: Write ;
22use std:: mem;
3+ use std:: num:: NonZero ;
34use std:: num:: NonZeroUsize ;
45
56use rustc_apfloat:: FloatConvert ;
7+ use rustc_apfloat:: Round ;
68use rustc_apfloat:: ieee:: Double ;
79use rustc_apfloat:: ieee:: Single ;
810
911use crate :: Args ;
1012use crate :: Commands ;
13+ use crate :: EvalCfg ;
1114use crate :: FloatRepr ;
12- use crate :: apf_fuzz:: FuzzOp ;
15+ use crate :: apf_fuzz:: Op ;
16+ use crate :: eval_all;
1317
14- pub fn run_exhaustive < F : FloatRepr > ( cli_args : & Args ) -> Result < ( ) , std :: num :: NonZero < usize > >
18+ pub fn run_exhaustive < F : FloatRepr > ( cli_args : & Args ) -> Result < ( ) , NonZero < usize > >
1519where
1620 F : Send + ' static ,
1721 Single : FloatConvert < F :: RustcApFloat > ,
@@ -38,21 +42,18 @@ where
3842 ..cli_args. clone ( )
3943 } ;
4044
41- let all_ops = ( 0 ..)
42- . map ( FuzzOp :: from_tag)
43- . take_while ( |op| op. is_some ( ) )
44- . map ( |op| op. unwrap ( ) )
45- . filter ( move |op| {
46- if only_non_trivial_fma {
47- matches ! ( op, FuzzOp :: MulAdd ( ..) )
48- } else {
49- true
50- }
51- } ) ;
45+ let all_ops = if only_non_trivial_fma {
46+ & [ Op :: MulAdd ]
47+ } else {
48+ Op :: ALL
49+ } ;
50+
51+ // This currently only tests round to nearest.
52+ let make_cfg =
53+ |op : Op , cli_args : & Args | EvalCfg :: new ( F :: KIND , op, Round :: NearestTiesToEven , cli_args) ;
5254
53- let op_to_combined_input_bits_range = move |op : FuzzOp < ( ) > | {
54- let mut total_bit_width = 0 ;
55- op. map ( |( ) | total_bit_width += F :: BIT_WIDTH ) ;
55+ let op_to_combined_input_bits_range = move |op : Op | {
56+ let total_bit_width = F :: BIT_WIDTH * ( op. airity ( ) as usize ) ;
5657
5758 // HACK(eddyb) the highest `F::BIT_WIDTH` bits are the last input,
5859 // i.e. the addend for FMA (see also `Commands::Bruteforce` docs).
@@ -62,29 +63,30 @@ where
6263 0
6364 } ;
6465
65- start_combined_input_bits..u128 :: checked_shl ( 1 , total_bit_width as u32 ) . unwrap ( )
66+ start_combined_input_bits..1_u128 . strict_shl ( total_bit_width as u32 )
6667 } ;
67- let op_to_exhaustive_cases = move |op : FuzzOp < ( ) > | {
68- op_to_combined_input_bits_range ( op) . map ( move |i| -> FuzzOp < F > {
69- let mut combined_input_bits = i;
70- let op_with_inputs = op. map ( |( ) | {
71- let x = combined_input_bits & ( ( 1 << F :: BIT_WIDTH ) - 1 ) ;
72- combined_input_bits >>= F :: BIT_WIDTH ;
73- F :: from_bits_u128 ( x)
74- } ) ;
75- assert_eq ! ( combined_input_bits, 0 ) ;
76- op_with_inputs
68+ let op_to_exhaustive_cases = move |op : Op | {
69+ op_to_combined_input_bits_range ( op) . map ( move |i| -> ( F , F , F ) {
70+ let mask = ( 1 << F :: BIT_WIDTH ) - 1 ;
71+ let a = ( i >> ( 0 * F :: BIT_WIDTH ) ) & mask;
72+ let b = ( i >> ( 1 * F :: BIT_WIDTH ) ) & mask;
73+ let c = ( i >> ( 2 * F :: BIT_WIDTH ) ) & mask;
74+ assert_eq ! ( i >> ( 3 * F :: BIT_WIDTH ) , 0 ) ;
75+ (
76+ F :: from_bits_u128 ( a) ,
77+ F :: from_bits_u128 ( b) ,
78+ F :: from_bits_u128 ( c) ,
79+ )
7780 } )
7881 } ;
7982
8083 let num_total_cases = all_ops
81- . clone ( )
84+ . iter ( )
8285 . map ( |op| {
83- let range = op_to_combined_input_bits_range ( op) ;
84- range. end . checked_sub ( range. start ) . unwrap ( )
86+ let range = op_to_combined_input_bits_range ( * op) ;
87+ range. end . strict_sub ( range. start )
8588 } )
86- . try_fold ( 0 , u128:: checked_add)
87- . unwrap ( ) ;
89+ . fold ( 0 , u128:: strict_add) ;
8890
8991 let float_name = F :: short_lowercase_name ( ) ;
9092 println ! ( "Exhaustively checking {num_total_cases} cases for {float_name}:" ) ;
@@ -124,28 +126,29 @@ where
124126 let cli_args = cli_args. clone ( ) ;
125127 let updates_tx = updates_tx. clone ( ) ;
126128 let cases_per_thread = all_ops
127- . clone ( )
128- . flat_map ( op_to_exhaustive_cases)
129+ . iter ( )
130+ . flat_map ( move |op| op_to_exhaustive_cases ( * op ) . map ( | ( a , b , c ) | ( * op , a , b , c ) ) )
129131 . skip ( thread_idx)
130132 . step_by ( num_threads. get ( ) ) ;
133+
131134 std:: thread:: spawn ( move || {
132135 let mut update = Update :: default ( ) ;
133- for op_with_inputs in cases_per_thread {
134- // HACK(eddyb) there are still panics we need to account for,
135- // e.g. https://github.com/llvm/llvm-project/issues/63895, and
136- // even if the Rust code didn't panic, LLVM asserts would trip.
137- match std :: panic :: catch_unwind ( std :: panic :: AssertUnwindSafe ( || {
138- op_with_inputs . eval ( & cli_args )
139- } ) ) {
136+
137+ for ( op , a , b , c ) in cases_per_thread {
138+ let cfg = make_cfg ( op , & cli_args ) ;
139+ let res = std :: panic:: catch_unwind ( std :: panic :: AssertUnwindSafe ( || {
140+ eval_all ( & cfg , a , b , c )
141+ } ) ) ;
142+ match res {
140143 Ok ( out) => {
141- if out. all_match ( ) {
144+ if out. check_all ( & cfg , a , b , c , false ) . is_ok ( ) {
142145 update. successes += 1 ;
143146 } else {
144- update. mismatch_or_panic = Some ( ( op_with_inputs , None ) ) ;
147+ update. mismatch_or_panic = Some ( ( ( op , a , b , c ) , None ) ) ;
145148 }
146149 }
147150 Err ( panic) => {
148- update. mismatch_or_panic = Some ( ( op_with_inputs , Some ( panic) ) ) ;
151+ update. mismatch_or_panic = Some ( ( ( op , a , b , c ) , Some ( panic) ) ) ;
149152 }
150153 }
151154
@@ -189,7 +192,9 @@ where
189192
190193 Err ( ( op_with_inputs, None ) ) => {
191194 if verbose {
192- op_with_inputs. print_op_and_eval_outputs ( cli_args) ;
195+ let ( op, a, b, c) = op_with_inputs;
196+ let cfg = make_cfg ( op, & cli_args) ;
197+ let _ = eval_all ( & cfg, a, b, c) . check_all ( & cfg, a, b, c, true ) ;
193198 }
194199
195200 last_mismatch_case_idx = Some ( case_idx) ;
@@ -200,7 +205,10 @@ where
200205
201206 Err ( ( op_with_inputs, Some ( panic) ) ) => {
202207 if verbose {
203- op_with_inputs. print_op_and_eval_outputs ( & cli_args_plus_ignore_cxx) ;
208+ let ( op, a, b, c) = op_with_inputs;
209+ let cfg = make_cfg ( op, & cli_args_plus_ignore_cxx) ;
210+ let _ = eval_all ( & cfg, a, b, c) . check_all ( & cfg, a, b, c, true ) ;
211+
204212 if let Ok ( msg) = panic. downcast :: < String > ( ) {
205213 eprintln ! ( "panicked with: {msg}" ) ;
206214 } else {
@@ -261,7 +269,7 @@ where
261269 // FIXME(eddyb) consider sorting these (and panics?) due to parallelism.
262270 let num_mismatches = all_mismatches. len ( ) ;
263271 let mut select_mismatches = all_mismatches;
264- select_mismatches. dedup_by_key ( |op_with_inputs| op_with_inputs . tag ( ) ) ;
272+ select_mismatches. dedup_by_key ( |( op , _ , _ , _ ) | * op ) ;
265273
266274 if num_mismatches > 0 {
267275 println ! ( ) ;
@@ -271,7 +279,9 @@ where
271279 select_mismatches. len( ) ,
272280 ) ;
273281 for mismatch in select_mismatches {
274- mismatch. print_op_and_eval_outputs ( cli_args) ;
282+ let ( op, a, b, c) = mismatch;
283+ let cfg = make_cfg ( op, & cli_args) ;
284+ let _ = eval_all ( & cfg, a, b, c) . check_all ( & cfg, a, b, c, true ) ;
275285 }
276286 }
277287
@@ -288,7 +298,9 @@ where
288298 ) ;
289299 if !verbose || verbose_failed_to_show_some_panics {
290300 for & panicking_case in & all_panics {
291- panicking_case. print_op_and_eval_outputs ( & cli_args_plus_ignore_cxx) ;
301+ let ( op, a, b, c) = panicking_case;
302+ let cfg = make_cfg ( op, & cli_args_plus_ignore_cxx) ;
303+ let _ = eval_all ( & cfg, a, b, c) . check_all ( & cfg, a, b, c, true ) ;
292304 }
293305 }
294306 }
0 commit comments