@@ -4,10 +4,8 @@ use ndarray::*;
44use ndarray_linalg:: * ;
55
66/// A is square. `x = A^{-1} b`, `|b - Ax| = 0`
7- #[ test]
8- fn least_squares_exact ( ) {
9- let a: Array2 < f64 > = random ( ( 3 , 3 ) ) ;
10- let b: Array1 < f64 > = random ( 3 ) ;
7+ fn test_exact < T : Scalar + Lapack > ( a : Array2 < T > ) {
8+ let b: Array1 < T > = random ( 3 ) ;
119 let result = a. least_squares ( & b) . unwrap ( ) ;
1210 // unpack result
1311 let x = result. solution ;
@@ -17,19 +15,43 @@ fn least_squares_exact() {
1715 assert_eq ! ( result. rank, 3 ) ;
1816
1917 // |b - Ax| == 0
20- assert ! ( residual_l2_square < 1.0e-7 ) ;
18+ assert ! ( residual_l2_square < T :: real ( 1.0e-4 ) ) ;
2119
2220 // b == Ax
2321 let ax = a. dot ( & x) ;
24- assert_close_l2 ! ( & b, & ax, 1.0e-7 ) ;
22+ assert_close_l2 ! ( & b, & ax, T :: real ( 1.0e-4 ) ) ;
2523}
2624
25+ macro_rules! impl_exact {
26+ ( $scalar: ty) => {
27+ paste:: item! {
28+ #[ test]
29+ fn [ <least_squares_ $scalar _exact>] ( ) {
30+ let a: Array2 <f64 > = random( ( 3 , 3 ) ) ;
31+ test_exact( a)
32+ }
33+
34+ #[ test]
35+ fn [ <least_squares_ $scalar _exact_t>] ( ) {
36+ let a: Array2 <f64 > = random( ( 3 , 3 ) . f( ) ) ;
37+ test_exact( a)
38+ }
39+ }
40+ } ;
41+ }
42+
43+ impl_exact ! ( f32 ) ;
44+ impl_exact ! ( f64 ) ;
45+ impl_exact ! ( c32) ;
46+ impl_exact ! ( c64) ;
47+
2748/// #column < #row case.
2849/// Linear problem is overdetermined, `|b - Ax| > 0`.
29- #[ test]
30- fn least_squares_overdetermined ( ) {
31- let a: Array2 < f64 > = random ( ( 4 , 3 ) ) ;
32- let b: Array1 < f64 > = random ( 4 ) ;
50+ fn test_overdetermined < T : Scalar + Lapack > ( a : Array2 < T > )
51+ where
52+ T :: Real : AbsDiffEq < Epsilon = T :: Real > ,
53+ {
54+ let b: Array1 < T > = random ( 4 ) ;
3355 let result = a. least_squares ( & b) . unwrap ( ) ;
3456 // unpack result
3557 let x = result. solution ;
@@ -40,24 +62,68 @@ fn least_squares_overdetermined() {
4062
4163 // eval `residual = b - Ax`
4264 let residual = & b - & a. dot ( & x) ;
43- assert ! ( residual_l2_square. abs_diff_eq( & residual. norm_l2( ) . powi( 2 ) , 1e-12 ) ) ;
65+ assert ! ( residual_l2_square. abs_diff_eq( & residual. norm_l2( ) . powi( 2 ) , T :: real ( 1.0e-4 ) ) ) ;
4466
4567 // `|residual| < |b|`
4668 assert ! ( residual. norm_l2( ) < b. norm_l2( ) ) ;
4769}
4870
71+ macro_rules! impl_overdetermined {
72+ ( $scalar: ty) => {
73+ paste:: item! {
74+ #[ test]
75+ fn [ <least_squares_ $scalar _overdetermined>] ( ) {
76+ let a: Array2 <f64 > = random( ( 4 , 3 ) ) ;
77+ test_overdetermined( a)
78+ }
79+
80+ #[ test]
81+ fn [ <least_squares_ $scalar _overdetermined_t>] ( ) {
82+ let a: Array2 <f64 > = random( ( 4 , 3 ) . f( ) ) ;
83+ test_overdetermined( a)
84+ }
85+ }
86+ } ;
87+ }
88+
89+ impl_overdetermined ! ( f32 ) ;
90+ impl_overdetermined ! ( f64 ) ;
91+ impl_overdetermined ! ( c32) ;
92+ impl_overdetermined ! ( c64) ;
93+
4994/// #column > #row case.
5095/// Linear problem is underdetermined, `|b - Ax| = 0` and `x` is not unique
51- #[ test]
52- fn least_squares_underdetermined ( ) {
53- let a: Array2 < f64 > = random ( ( 3 , 4 ) ) ;
54- let b: Array1 < f64 > = random ( 3 ) ;
96+ fn test_underdetermined < T : Scalar + Lapack > ( a : Array2 < T > ) {
97+ let b: Array1 < T > = random ( 3 ) ;
5598 let result = a. least_squares ( & b) . unwrap ( ) ;
5699 assert_eq ! ( result. rank, 3 ) ;
57100 assert ! ( result. residual_sum_of_squares. is_none( ) ) ;
58101
59102 // b == Ax
60103 let x = result. solution ;
61104 let ax = a. dot ( & x) ;
62- assert_close_l2 ! ( & b, & ax, 1.0e-7 ) ;
105+ assert_close_l2 ! ( & b, & ax, T :: real( 1.0e-4 ) ) ;
106+ }
107+
108+ macro_rules! impl_underdetermined {
109+ ( $scalar: ty) => {
110+ paste:: item! {
111+ #[ test]
112+ fn [ <least_squares_ $scalar _underdetermined>] ( ) {
113+ let a: Array2 <f64 > = random( ( 3 , 4 ) ) ;
114+ test_underdetermined( a)
115+ }
116+
117+ #[ test]
118+ fn [ <least_squares_ $scalar _underdetermined_t>] ( ) {
119+ let a: Array2 <f64 > = random( ( 3 , 4 ) . f( ) ) ;
120+ test_underdetermined( a)
121+ }
122+ }
123+ } ;
63124}
125+
126+ impl_underdetermined ! ( f32 ) ;
127+ impl_underdetermined ! ( f64 ) ;
128+ impl_underdetermined ! ( c32) ;
129+ impl_underdetermined ! ( c64) ;
0 commit comments