@@ -28,6 +28,28 @@ pub struct Tridiagonal<A: Scalar> {
2828 pub du : Vec < A > ,
2929}
3030
31+ impl < A : Scalar > Tridiagonal < A > {
32+ fn opnorm_one ( & self ) -> A :: Real {
33+ let n = self . l . len ( ) as usize ;
34+ let mut col_sum: Vec < A :: Real > = self . d . iter ( ) . map ( |val| val. abs ( ) ) . collect ( ) ;
35+ for i in 0 ..n - 1 {
36+ if i < n - 1 {
37+ col_sum[ i] += self . dl [ i + 1 ] . abs ( ) ;
38+ }
39+ if i > 0 {
40+ col_sum[ i] += self . du [ i - 1 ] . abs ( ) ;
41+ }
42+ }
43+ let mut max = A :: Real :: zero ( ) ;
44+ for & val in & col_sum {
45+ if max < val {
46+ max = val;
47+ }
48+ }
49+ max
50+ }
51+ }
52+
3153/// Represents the LU factorization of a tridiagonal matrix `A` as `A = P*L*U`.
3254#[ derive( Clone , PartialEq ) ]
3355pub struct LUFactorizedTridiagonal < A : Scalar > {
@@ -41,6 +63,8 @@ pub struct LUFactorizedTridiagonal<A: Scalar> {
4163 pub du2 : Vec < A > ,
4264 /// The pivot indices that define the permutation matrix `P`.
4365 pub ipiv : Pivot ,
66+
67+ a_opnorm_one : A :: Real ,
4468}
4569
4670impl < A : Scalar > Index < ( i32 , i32 ) > for Tridiagonal < A > {
@@ -66,6 +90,14 @@ impl<A: Scalar> Index<(i32, i32)> for Tridiagonal<A> {
6690 }
6791}
6892
93+ impl < A : Scalar > Index < [ i32 ; 2 ] > for Tridiagonal < A > {
94+ type Output = A ;
95+ #[ inline]
96+ fn index ( & self , [ row, col] : [ i32 ; 2 ] ) -> & A {
97+ & self [ ( row, col) ]
98+ }
99+ }
100+
69101impl < A : Scalar > IndexMut < ( i32 , i32 ) > for Tridiagonal < A > {
70102 #[ inline]
71103 fn index_mut ( & mut self , ( row, col) : ( i32 , i32 ) ) -> & mut A {
@@ -88,11 +120,18 @@ impl<A: Scalar> IndexMut<(i32, i32)> for Tridiagonal<A> {
88120 }
89121}
90122
123+ impl < A : Scalar > IndexMut < [ i32 ; 2 ] > for Tridiagonal < A > {
124+ #[ inline]
125+ fn index_mut ( & mut self , [ row, col] : [ i32 ; 2 ] ) -> & mut A {
126+ & mut self [ ( row, col) ]
127+ }
128+ }
129+
91130/// Wraps `*gttrf`, `*gtcon` and `*gttrs`
92131pub trait Tridiagonal_ : Scalar + Sized {
93132 /// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using
94133 /// partial pivoting with row interchanges.
95- unsafe fn lu_tridiagonal ( a : & mut Tridiagonal < Self > ) -> Result < ( Vec < Self > , Pivot ) > ;
134+ unsafe fn lu_tridiagonal ( a : Tridiagonal < Self > ) -> Result < LUFactorizedTridiagonal < Self > > ;
96135
97136 unsafe fn rcond_tridiagonal ( lu : & LUFactorizedTridiagonal < Self > ) -> Result < Self :: Real > ;
98137
@@ -107,19 +146,27 @@ pub trait Tridiagonal_: Scalar + Sized {
107146macro_rules! impl_tridiagonal {
108147 ( $scalar: ty, $gttrf: path, $gtcon: path, $gttrs: path) => {
109148 impl Tridiagonal_ for $scalar {
110- unsafe fn lu_tridiagonal( a: & mut Tridiagonal <Self >) -> Result <( Vec <Self >, Pivot ) > {
149+ unsafe fn lu_tridiagonal(
150+ mut a: Tridiagonal <Self >,
151+ ) -> Result <LUFactorizedTridiagonal <Self >> {
111152 let ( n, _) = a. l. size( ) ;
112153 let mut du2 = vec![ Zero :: zero( ) ; ( n - 2 ) as usize ] ;
113154 let mut ipiv = vec![ 0 ; n as usize ] ;
155+ // We have to calc one-norm before LU factorization
156+ let a_opnorm_one = a. opnorm_one( ) ;
114157 $gttrf( n, & mut a. dl, & mut a. d, & mut a. du, & mut du2, & mut ipiv)
115158 . as_lapack_result( ) ?;
116- Ok ( ( du2, ipiv) )
159+ Ok ( LUFactorizedTridiagonal {
160+ a,
161+ du2,
162+ ipiv,
163+ a_opnorm_one,
164+ } )
117165 }
118166
119167 unsafe fn rcond_tridiagonal( lu: & LUFactorizedTridiagonal <Self >) -> Result <Self :: Real > {
120168 let ( n, _) = lu. a. l. size( ) ;
121169 let ipiv = & lu. ipiv;
122- let anorm = lu. anom;
123170 let mut rcond = Self :: Real :: zero( ) ;
124171 $gtcon(
125172 NormType :: One as u8 ,
@@ -129,7 +176,7 @@ macro_rules! impl_tridiagonal {
129176 & lu. a. du,
130177 & lu. du2,
131178 ipiv,
132- anorm ,
179+ lu . a_opnorm_one ,
133180 & mut rcond,
134181 )
135182 . as_lapack_result( ) ?;
0 commit comments