|
| 1 | +/// Solve least square problem `|b - Ax|` |
1 | 2 | use approx::AbsDiffEq; |
2 | 3 | use ndarray::*; |
3 | 4 | use ndarray_linalg::*; |
4 | | -use num_complex::Complex; |
5 | 5 |
|
6 | | -fn c(re: f64, im: f64) -> Complex<f64> { |
7 | | - Complex::new(re, im) |
8 | | -} |
9 | | - |
10 | | -// |
11 | | -// Test cases taken from the scipy test suite for the scipy lstsq function |
12 | | -// https://github.com/scipy/scipy/blob/v1.4.1/scipy/linalg/tests/basic.py |
13 | | -// |
| 6 | +/// A is square. `x = A^{-1} b`, `|b - Ax| = 0` |
14 | 7 | #[test] |
15 | 8 | fn least_squares_exact() { |
16 | | - let a = array![[1., 20.], [-30., 4.]]; |
17 | | - let bs = vec![ |
18 | | - array![[1., 0.], [0., 1.]], |
19 | | - array![[1.], [0.]], |
20 | | - array![[2., 1.], [-30., 4.]], |
21 | | - ]; |
22 | | - for b in &bs { |
23 | | - let res = a.least_squares(b).unwrap(); |
24 | | - assert_eq!(res.rank, 2); |
25 | | - let b_hat = a.dot(&res.solution); |
26 | | - let rssq = (b - &b_hat).mapv(|x| x.powi(2)).sum_axis(Axis(0)); |
27 | | - assert!(res |
28 | | - .residual_sum_of_squares |
29 | | - .unwrap() |
30 | | - .abs_diff_eq(&rssq, 1e-12)); |
31 | | - assert!(b_hat.abs_diff_eq(&b, 1e-12)); |
32 | | - } |
33 | | -} |
| 9 | + let a: Array2<f64> = random((3, 3)); |
| 10 | + let b: Array1<f64> = random(3); |
| 11 | + let result = a.least_squares(&b).unwrap(); |
| 12 | + // unpack result |
| 13 | + let x = result.solution; |
| 14 | + let residual_l2_square = result.residual_sum_of_squares.unwrap()[()]; |
34 | 15 |
|
35 | | -#[test] |
36 | | -fn least_squares_overdetermined() { |
37 | | - let a: Array2<f64> = array![[1., 2.], [4., 5.], [3., 4.]]; |
38 | | - let b: Array1<f64> = array![1., 2., 3.]; |
39 | | - let res = a.least_squares(&b).unwrap(); |
40 | | - assert_eq!(res.rank, 2); |
41 | | - let b_hat = a.dot(&res.solution); |
42 | | - let rssq = (&b - &b_hat).mapv(|x| x.powi(2)).sum(); |
43 | | - assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12)); |
44 | | - assert!(res |
45 | | - .solution |
46 | | - .abs_diff_eq(&array![-0.428571428571429, 0.85714285714285], 1e-12)); |
47 | | -} |
| 16 | + // must be full-rank |
| 17 | + assert_eq!(result.rank, 3); |
48 | 18 |
|
49 | | -#[test] |
50 | | -fn least_squares_overdetermined_complex() { |
51 | | - let a: Array2<c64> = array![ |
52 | | - [c(1., 2.), c(2., 0.)], |
53 | | - [c(4., 0.), c(5., 0.)], |
54 | | - [c(3., 0.), c(4., 0.)] |
55 | | - ]; |
56 | | - let b: Array1<c64> = array![c(1., 0.), c(2., 4.), c(3., 0.)]; |
57 | | - let res = a.least_squares(&b).unwrap(); |
58 | | - assert_eq!(res.rank, 2); |
59 | | - let b_hat = a.dot(&res.solution); |
60 | | - let rssq = (&b_hat - &b).mapv(|x| x.powi(2).abs()).sum(); |
61 | | - assert!(res.residual_sum_of_squares.unwrap()[()].abs_diff_eq(&rssq, 1e-12)); |
62 | | - assert!(res.solution.abs_diff_eq( |
63 | | - &array![ |
64 | | - c(-0.4831460674157303, 0.258426966292135), |
65 | | - c(0.921348314606741, 0.292134831460674) |
66 | | - ], |
67 | | - 1e-12 |
68 | | - )); |
69 | | -} |
| 19 | + // |b - Ax| == 0 |
| 20 | + assert!(residual_l2_square < 1.0e-7); |
70 | 21 |
|
71 | | -#[test] |
72 | | -fn least_squares_underdetermined() { |
73 | | - let a: Array2<f64> = array![[1., 2., 3.], [4., 5., 6.]]; |
74 | | - let b: Array1<f64> = array![1., 2.]; |
75 | | - let res = a.least_squares(&b).unwrap(); |
76 | | - assert_eq!(res.rank, 2); |
77 | | - assert!(res.residual_sum_of_squares.is_none()); |
78 | | - let expected = array![-0.055555555555555, 0.111111111111111, 0.277777777777777]; |
79 | | - assert!(res.solution.abs_diff_eq(&expected, 1e-12)); |
| 22 | + // b == Ax |
| 23 | + let ax = a.dot(&x); |
| 24 | + assert_close_l2!(&b, &ax, 1.0e-7); |
80 | 25 | } |
81 | 26 |
|
82 | | -/// This test case tests the underdetermined case for multiple right hand |
83 | | -/// sides. Adapted from scipy lstsq tests. |
| 27 | +/// #column < #row case. |
| 28 | +/// Linear problem is overdetermined, `|b - Ax| > 0`. |
84 | 29 | #[test] |
85 | | -fn least_squares_underdetermined_nrhs() { |
86 | | - let a: Array2<f64> = array![[1., 2., 3.], [4., 5., 6.]]; |
87 | | - let b: Array2<f64> = array![[1., 1.], [2., 2.]]; |
88 | | - let res = a.least_squares(&b).unwrap(); |
89 | | - assert_eq!(res.rank, 2); |
90 | | - assert!(res.residual_sum_of_squares.is_none()); |
91 | | - let expected = array![ |
92 | | - [-0.055555555555555, -0.055555555555555], |
93 | | - [0.111111111111111, 0.111111111111111], |
94 | | - [0.277777777777777, 0.277777777777777] |
95 | | - ]; |
96 | | - assert!(res.solution.abs_diff_eq(&expected, 1e-12)); |
97 | | -} |
98 | | - |
99 | | -// |
100 | | -// Test cases taken from the netlib documentation at |
101 | | -// https://www.netlib.org/lapack/lapacke.html#_calling_code_dgels_code |
102 | | -// |
103 | | -#[test] |
104 | | -fn netlib_lapack_example_for_dgels_1() { |
105 | | - let a: Array2<f64> = array![ |
106 | | - [1., 1., 1.], |
107 | | - [2., 3., 4.], |
108 | | - [3., 5., 2.], |
109 | | - [4., 2., 5.], |
110 | | - [5., 4., 3.] |
111 | | - ]; |
112 | | - let b: Array1<f64> = array![-10., 12., 14., 16., 18.]; |
113 | | - let expected: Array1<f64> = array![2., 1., 1.]; |
| 30 | +fn least_squares_overdetermined() { |
| 31 | + let a: Array2<f64> = random((4, 3)); |
| 32 | + let b: Array1<f64> = random(4); |
114 | 33 | let result = a.least_squares(&b).unwrap(); |
115 | | - assert!(result.solution.abs_diff_eq(&expected, 1e-12)); |
| 34 | + // unpack result |
| 35 | + let x = result.solution; |
| 36 | + let residual_l2_square = result.residual_sum_of_squares.unwrap()[()]; |
116 | 37 |
|
117 | | - let residual = b - a.dot(&result.solution); |
118 | | - let resid_ssq = result.residual_sum_of_squares.unwrap(); |
119 | | - assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12); |
120 | | -} |
| 38 | + // Must be full-rank |
| 39 | + assert_eq!(result.rank, 3); |
121 | 40 |
|
122 | | -#[test] |
123 | | -fn netlib_lapack_example_for_dgels_2() { |
124 | | - let a: Array2<f64> = array![ |
125 | | - [1., 1., 1.], |
126 | | - [2., 3., 4.], |
127 | | - [3., 5., 2.], |
128 | | - [4., 2., 5.], |
129 | | - [5., 4., 3.] |
130 | | - ]; |
131 | | - let b: Array1<f64> = array![-3., 14., 12., 16., 16.]; |
132 | | - let expected: Array1<f64> = array![1., 1., 2.]; |
133 | | - let result = a.least_squares(&b).unwrap(); |
134 | | - assert!(result.solution.abs_diff_eq(&expected, 1e-12)); |
| 41 | + // eval `residual = b - Ax` |
| 42 | + let residual = &b - &a.dot(&x); |
| 43 | + assert!(residual_l2_square.abs_diff_eq(&residual.norm_l2().powi(2), 1e-12)); |
135 | 44 |
|
136 | | - let residual = b - a.dot(&result.solution); |
137 | | - let resid_ssq = result.residual_sum_of_squares.unwrap(); |
138 | | - assert!((resid_ssq[()] - residual.dot(&residual)).abs() < 1e-12); |
| 45 | + // `|residual| < |b|` |
| 46 | + assert!(residual.norm_l2() < b.norm_l2()); |
139 | 47 | } |
140 | 48 |
|
| 49 | +/// #column > #row case. |
| 50 | +/// Linear problem is underdetermined, `|b - Ax| = 0` and `x` is not unique |
141 | 51 | #[test] |
142 | | -fn netlib_lapack_example_for_dgels_nrhs() { |
143 | | - let a: Array2<f64> = array![ |
144 | | - [1., 1., 1.], |
145 | | - [2., 3., 4.], |
146 | | - [3., 5., 2.], |
147 | | - [4., 2., 5.], |
148 | | - [5., 4., 3.] |
149 | | - ]; |
150 | | - let b: Array2<f64> = array![[-10., -3.], [12., 14.], [14., 12.], [16., 16.], [18., 16.]]; |
151 | | - let expected: Array2<f64> = array![[2., 1.], [1., 1.], [1., 2.]]; |
| 52 | +fn least_squares_underdetermined() { |
| 53 | + let a: Array2<f64> = random((3, 4)); |
| 54 | + let b: Array1<f64> = random(3); |
152 | 55 | let result = a.least_squares(&b).unwrap(); |
153 | | - assert!(result.solution.abs_diff_eq(&expected, 1e-12)); |
| 56 | + assert_eq!(result.rank, 3); |
| 57 | + assert!(result.residual_sum_of_squares.is_none()); |
154 | 58 |
|
155 | | - let residual = &b - &a.dot(&result.solution); |
156 | | - let residual_ssq = residual.mapv(|x| x.powi(2)).sum_axis(Axis(0)); |
157 | | - assert!(result |
158 | | - .residual_sum_of_squares |
159 | | - .unwrap() |
160 | | - .abs_diff_eq(&residual_ssq, 1e-12)); |
| 59 | + // b == Ax |
| 60 | + let x = result.solution; |
| 61 | + let ax = a.dot(&x); |
| 62 | + assert_close_l2!(&b, &ax, 1.0e-7); |
161 | 63 | } |
0 commit comments