@@ -414,117 +414,9 @@ fn compute_residual_array1<E: Scalar, D: Data<Elem = E>>(
414414
415415#[ cfg( test) ]
416416mod tests {
417- use super :: * ;
417+ use crate :: { error :: LinalgError , * } ;
418418 use approx:: AbsDiffEq ;
419- use ndarray:: { ArcArray1 , ArcArray2 , Array1 , Array2 , CowArray } ;
420- use num_complex:: Complex ;
421-
422- //
423- // Test cases taken from the scipy test suite for the scipy lstsq function
424- // https://github.com/scipy/scipy/blob/v1.4.1/scipy/linalg/tests/test_basic.py
425- //
426- #[ test]
427- fn scipy_test_simple_exact ( ) {
428- let a = array ! [ [ 1. , 20. ] , [ -30. , 4. ] ] ;
429- let bs = vec ! [
430- array![ [ 1. , 0. ] , [ 0. , 1. ] ] ,
431- array![ [ 1. ] , [ 0. ] ] ,
432- array![ [ 2. , 1. ] , [ -30. , 4. ] ] ,
433- ] ;
434- for b in & bs {
435- let res = a. least_squares ( b) . unwrap ( ) ;
436- assert_eq ! ( res. rank, 2 ) ;
437- let b_hat = a. dot ( & res. solution ) ;
438- let rssq = ( b - & b_hat) . mapv ( |x| x. powi ( 2 ) ) . sum_axis ( Axis ( 0 ) ) ;
439- assert ! ( res
440- . residual_sum_of_squares
441- . unwrap( )
442- . abs_diff_eq( & rssq, 1e-12 ) ) ;
443- assert ! ( b_hat. abs_diff_eq( & b, 1e-12 ) ) ;
444- }
445- }
446-
447- #[ test]
448- fn scipy_test_simple_overdetermined ( ) {
449- let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
450- let b: Array1 < f64 > = array ! [ 1. , 2. , 3. ] ;
451- let res = a. least_squares ( & b) . unwrap ( ) ;
452- assert_eq ! ( res. rank, 2 ) ;
453- let b_hat = a. dot ( & res. solution ) ;
454- let rssq = ( & b - & b_hat) . mapv ( |x| x. powi ( 2 ) ) . sum ( ) ;
455- assert ! ( res. residual_sum_of_squares. unwrap( ) [ ( ) ] . abs_diff_eq( & rssq, 1e-12 ) ) ;
456- assert ! ( res
457- . solution
458- . abs_diff_eq( & array![ -0.428571428571429 , 0.85714285714285 ] , 1e-12 ) ) ;
459- }
460-
461- #[ test]
462- fn scipy_test_simple_overdetermined_f32 ( ) {
463- let a: Array2 < f32 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
464- let b: Array1 < f32 > = array ! [ 1. , 2. , 3. ] ;
465- let res = a. least_squares ( & b) . unwrap ( ) ;
466- assert_eq ! ( res. rank, 2 ) ;
467- let b_hat = a. dot ( & res. solution ) ;
468- let rssq = ( & b - & b_hat) . mapv ( |x| x. powi ( 2 ) ) . sum ( ) ;
469- assert ! ( res. residual_sum_of_squares. unwrap( ) [ ( ) ] . abs_diff_eq( & rssq, 1e-6 ) ) ;
470- assert ! ( res
471- . solution
472- . abs_diff_eq( & array![ -0.428571428571429 , 0.85714285714285 ] , 1e-6 ) ) ;
473- }
474-
475- fn c ( re : f64 , im : f64 ) -> Complex < f64 > {
476- Complex :: new ( re, im)
477- }
478-
479- #[ test]
480- fn scipy_test_simple_overdetermined_complex ( ) {
481- let a: Array2 < c64 > = array ! [
482- [ c( 1. , 2. ) , c( 2. , 0. ) ] ,
483- [ c( 4. , 0. ) , c( 5. , 0. ) ] ,
484- [ c( 3. , 0. ) , c( 4. , 0. ) ]
485- ] ;
486- let b: Array1 < c64 > = array ! [ c( 1. , 0. ) , c( 2. , 4. ) , c( 3. , 0. ) ] ;
487- let res = a. least_squares ( & b) . unwrap ( ) ;
488- assert_eq ! ( res. rank, 2 ) ;
489- let b_hat = a. dot ( & res. solution ) ;
490- let rssq = ( & b_hat - & b) . mapv ( |x| x. powi ( 2 ) . abs ( ) ) . sum ( ) ;
491- assert ! ( res. residual_sum_of_squares. unwrap( ) [ ( ) ] . abs_diff_eq( & rssq, 1e-12 ) ) ;
492- assert ! ( res. solution. abs_diff_eq(
493- & array![
494- c( -0.4831460674157303 , 0.258426966292135 ) ,
495- c( 0.921348314606741 , 0.292134831460674 )
496- ] ,
497- 1e-12
498- ) ) ;
499- }
500-
501- #[ test]
502- fn scipy_test_simple_underdetermined ( ) {
503- let a: Array2 < f64 > = array ! [ [ 1. , 2. , 3. ] , [ 4. , 5. , 6. ] ] ;
504- let b: Array1 < f64 > = array ! [ 1. , 2. ] ;
505- let res = a. least_squares ( & b) . unwrap ( ) ;
506- assert_eq ! ( res. rank, 2 ) ;
507- assert ! ( res. residual_sum_of_squares. is_none( ) ) ;
508- let expected = array ! [ -0.055555555555555 , 0.111111111111111 , 0.277777777777777 ] ;
509- assert ! ( res. solution. abs_diff_eq( & expected, 1e-12 ) ) ;
510- }
511-
512- /// This test case tests the underdetermined case for multiple right hand
513- /// sides. Adapted from scipy lstsq tests.
514- #[ test]
515- fn scipy_test_simple_underdetermined_nrhs ( ) {
516- let a: Array2 < f64 > = array ! [ [ 1. , 2. , 3. ] , [ 4. , 5. , 6. ] ] ;
517- let b: Array2 < f64 > = array ! [ [ 1. , 1. ] , [ 2. , 2. ] ] ;
518- let res = a. least_squares ( & b) . unwrap ( ) ;
519- assert_eq ! ( res. rank, 2 ) ;
520- assert ! ( res. residual_sum_of_squares. is_none( ) ) ;
521- let expected = array ! [
522- [ -0.055555555555555 , -0.055555555555555 ] ,
523- [ 0.111111111111111 , 0.111111111111111 ] ,
524- [ 0.277777777777777 , 0.277777777777777 ]
525- ] ;
526- assert ! ( res. solution. abs_diff_eq( & expected, 1e-12 ) ) ;
527- }
419+ use ndarray:: * ;
528420
529421 //
530422 // Test that the different lest squares traits work as intended on the
@@ -554,23 +446,23 @@ mod tests {
554446 }
555447
556448 #[ test]
557- fn test_least_squares_on_arc ( ) {
449+ fn on_arc ( ) {
558450 let a: ArcArray2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] . into_shared ( ) ;
559451 let b: ArcArray1 < f64 > = array ! [ 1. , 2. , 3. ] . into_shared ( ) ;
560452 let res = a. least_squares ( & b) . unwrap ( ) ;
561453 assert_result ( & a, & b, & res) ;
562454 }
563455
564456 #[ test]
565- fn test_least_squares_on_cow ( ) {
457+ fn on_cow ( ) {
566458 let a = CowArray :: from ( array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ) ;
567459 let b = CowArray :: from ( array ! [ 1. , 2. , 3. ] ) ;
568460 let res = a. least_squares ( & b) . unwrap ( ) ;
569461 assert_result ( & a, & b, & res) ;
570462 }
571463
572464 #[ test]
573- fn test_least_squares_on_view ( ) {
465+ fn on_view ( ) {
574466 let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
575467 let b: Array1 < f64 > = array ! [ 1. , 2. , 3. ] ;
576468 let av = a. view ( ) ;
@@ -580,7 +472,7 @@ mod tests {
580472 }
581473
582474 #[ test]
583- fn test_least_squares_on_view_mut ( ) {
475+ fn on_view_mut ( ) {
584476 let mut a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
585477 let mut b: Array1 < f64 > = array ! [ 1. , 2. , 3. ] ;
586478 let av = a. view_mut ( ) ;
@@ -590,7 +482,7 @@ mod tests {
590482 }
591483
592484 #[ test]
593- fn test_least_squares_into_on_owned ( ) {
485+ fn into_on_owned ( ) {
594486 let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
595487 let b: Array1 < f64 > = array ! [ 1. , 2. , 3. ] ;
596488 let ac = a. clone ( ) ;
@@ -600,7 +492,7 @@ mod tests {
600492 }
601493
602494 #[ test]
603- fn test_least_squares_into_on_arc ( ) {
495+ fn into_on_arc ( ) {
604496 let a: ArcArray2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] . into_shared ( ) ;
605497 let b: ArcArray1 < f64 > = array ! [ 1. , 2. , 3. ] . into_shared ( ) ;
606498 let a2 = a. clone ( ) ;
@@ -610,7 +502,7 @@ mod tests {
610502 }
611503
612504 #[ test]
613- fn test_least_squares_into_on_cow ( ) {
505+ fn into_on_cow ( ) {
614506 let a = CowArray :: from ( array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ) ;
615507 let b = CowArray :: from ( array ! [ 1. , 2. , 3. ] ) ;
616508 let a2 = a. clone ( ) ;
@@ -620,7 +512,7 @@ mod tests {
620512 }
621513
622514 #[ test]
623- fn test_least_squares_in_place_on_owned ( ) {
515+ fn in_place_on_owned ( ) {
624516 let a = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
625517 let b = array ! [ 1. , 2. , 3. ] ;
626518 let mut a2 = a. clone ( ) ;
@@ -630,7 +522,7 @@ mod tests {
630522 }
631523
632524 #[ test]
633- fn test_least_squares_in_place_on_cow ( ) {
525+ fn in_place_on_cow ( ) {
634526 let a = CowArray :: from ( array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ) ;
635527 let b = CowArray :: from ( array ! [ 1. , 2. , 3. ] ) ;
636528 let mut a2 = a. clone ( ) ;
@@ -640,7 +532,7 @@ mod tests {
640532 }
641533
642534 #[ test]
643- fn test_least_squares_in_place_on_mut_view ( ) {
535+ fn in_place_on_mut_view ( ) {
644536 let a = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
645537 let b = array ! [ 1. , 2. , 3. ] ;
646538 let mut a2 = a. clone ( ) ;
@@ -651,95 +543,30 @@ mod tests {
651543 assert_result ( & a, & b, & res) ;
652544 }
653545
654- //
655- // Test cases taken from the netlib documentation at
656- // https://www.netlib.org/lapack/lapacke.html#_calling_code_dgels_code
657- //
658- #[ test]
659- fn netlib_lapack_example_for_dgels_1 ( ) {
660- let a: Array2 < f64 > = array ! [
661- [ 1. , 1. , 1. ] ,
662- [ 2. , 3. , 4. ] ,
663- [ 3. , 5. , 2. ] ,
664- [ 4. , 2. , 5. ] ,
665- [ 5. , 4. , 3. ]
666- ] ;
667- let b: Array1 < f64 > = array ! [ -10. , 12. , 14. , 16. , 18. ] ;
668- let expected: Array1 < f64 > = array ! [ 2. , 1. , 1. ] ;
669- let result = a. least_squares ( & b) . unwrap ( ) ;
670- assert ! ( result. solution. abs_diff_eq( & expected, 1e-12 ) ) ;
671-
672- let residual = b - a. dot ( & result. solution ) ;
673- let resid_ssq = result. residual_sum_of_squares . unwrap ( ) ;
674- assert ! ( ( resid_ssq[ ( ) ] - residual. dot( & residual) ) . abs( ) < 1e-12 ) ;
675- }
676-
677- #[ test]
678- fn netlib_lapack_example_for_dgels_2 ( ) {
679- let a: Array2 < f64 > = array ! [
680- [ 1. , 1. , 1. ] ,
681- [ 2. , 3. , 4. ] ,
682- [ 3. , 5. , 2. ] ,
683- [ 4. , 2. , 5. ] ,
684- [ 5. , 4. , 3. ]
685- ] ;
686- let b: Array1 < f64 > = array ! [ -3. , 14. , 12. , 16. , 16. ] ;
687- let expected: Array1 < f64 > = array ! [ 1. , 1. , 2. ] ;
688- let result = a. least_squares ( & b) . unwrap ( ) ;
689- assert ! ( result. solution. abs_diff_eq( & expected, 1e-12 ) ) ;
690-
691- let residual = b - a. dot ( & result. solution ) ;
692- let resid_ssq = result. residual_sum_of_squares . unwrap ( ) ;
693- assert ! ( ( resid_ssq[ ( ) ] - residual. dot( & residual) ) . abs( ) < 1e-12 ) ;
694- }
695-
696- #[ test]
697- fn netlib_lapack_example_for_dgels_nrhs ( ) {
698- let a: Array2 < f64 > = array ! [
699- [ 1. , 1. , 1. ] ,
700- [ 2. , 3. , 4. ] ,
701- [ 3. , 5. , 2. ] ,
702- [ 4. , 2. , 5. ] ,
703- [ 5. , 4. , 3. ]
704- ] ;
705- let b: Array2 < f64 > = array ! [ [ -10. , -3. ] , [ 12. , 14. ] , [ 14. , 12. ] , [ 16. , 16. ] , [ 18. , 16. ] ] ;
706- let expected: Array2 < f64 > = array ! [ [ 2. , 1. ] , [ 1. , 1. ] , [ 1. , 2. ] ] ;
707- let result = a. least_squares ( & b) . unwrap ( ) ;
708- assert ! ( result. solution. abs_diff_eq( & expected, 1e-12 ) ) ;
709-
710- let residual = & b - & a. dot ( & result. solution ) ;
711- let residual_ssq = residual. mapv ( |x| x. powi ( 2 ) ) . sum_axis ( Axis ( 0 ) ) ;
712- assert ! ( result
713- . residual_sum_of_squares
714- . unwrap( )
715- . abs_diff_eq( & residual_ssq, 1e-12 ) ) ;
716- }
717-
718546 //
719547 // Testing error cases
720548 //
721- use crate :: layout:: MatrixLayout ;
722549
723550 #[ test]
724- fn test_incompatible_shape_error_on_mismatching_num_rows ( ) {
551+ fn incompatible_shape_error_on_mismatching_num_rows ( ) {
725552 let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
726553 let b: Array1 < f64 > = array ! [ 1. , 2. ] ;
727554 let res = a. least_squares ( & b) ;
728555 match res {
729- Err ( LinalgError :: Lapack ( err) ) if matches ! ( err, lapack :: error:: Error :: InvalidShape ) => { }
556+ Err ( LinalgError :: Lapack ( err) ) if matches ! ( err, lax :: error:: Error :: InvalidShape ) => { }
730557 _ => panic ! ( "Expected Err()" ) ,
731558 }
732559 }
733560
734561 #[ test]
735- fn test_incompatible_shape_error_on_mismatching_layout ( ) {
562+ fn incompatible_shape_error_on_mismatching_layout ( ) {
736563 let a: Array2 < f64 > = array ! [ [ 1. , 2. ] , [ 4. , 5. ] , [ 3. , 4. ] ] ;
737564 let b = array ! [ [ 1. ] , [ 2. ] ] . t ( ) . to_owned ( ) ;
738565 assert_eq ! ( b. layout( ) . unwrap( ) , MatrixLayout :: F { col: 2 , lda: 1 } ) ;
739566
740567 let res = a. least_squares ( & b) ;
741568 match res {
742- Err ( LinalgError :: Lapack ( err) ) if matches ! ( err, lapack :: error:: Error :: InvalidShape ) => { }
569+ Err ( LinalgError :: Lapack ( err) ) if matches ! ( err, lax :: error:: Error :: InvalidShape ) => { }
743570 _ => panic ! ( "Expected Err()" ) ,
744571 }
745572 }
0 commit comments