Skip to content

Commit 70c7601

Browse files
SteveBronder and syclik
authored and committed
more cleanup
1 parent c150da9 commit 70c7601

2 files changed

Lines changed: 78 additions & 101 deletions

File tree

stan/math/prim/fun/log_gamma_q_dgamma.hpp

Lines changed: 24 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -20,17 +20,6 @@
2020
namespace stan {
2121
namespace math {
2222

23-
/**
24-
* Result structure containing log(Q(a,z)) and its gradient with respect to a.
25-
*
26-
* @tparam T return type
27-
*/
28-
template <typename T>
29-
struct log_gamma_q_result {
30-
T log_q; ///< log(Q(a,z)) where Q is upper regularized incomplete gamma
31-
T dlog_q_da; ///< d/da log(Q(a,z))
32-
};
33-
3423
namespace internal {
3524

3625
/**
@@ -41,40 +30,32 @@ namespace internal {
4130
* @tparam T_z Type of value parameter z (double or fvar types)
4231
* @param a Shape parameter
4332
* @param z Value at which to evaluate
44-
* @param precision Convergence threshold
33+
* @param precision Convergence threshold, default of sqrt(machine_epsilon)
4534
* @param max_steps Maximum number of continued fraction iterations
4635
* @return log(Q(a,z)) with the return type of T_a and T_z
4736
*/
4837
template <typename T_a, typename T_z>
4938
inline return_type_t<T_a, T_z> log_q_gamma_cf(const T_a& a, const T_z& z,
50-
double precision = 1e-16,
39+
double precision = 1.49012e-08,
5140
int max_steps = 250) {
5241
using T_return = return_type_t<T_a, T_z>;
5342
const auto log_prefactor = a * log(z) - z - lgamma(a);
5443

55-
auto b = z + 1.0 - a;
56-
auto C = (fabs(value_of(b)) >= EPSILON) ? b : std::decay_t<decltype(b)>(EPSILON);
44+
auto b_init = z + 1.0 - a;
45+
auto C = (fabs(value_of(b_init)) >= EPSILON) ? b_init : std::decay_t<decltype(b_init)>(EPSILON);
5746
auto D = 0.0;
5847
auto f = C;
59-
6048
for (int i = 1; i <= max_steps; ++i) {
6149
T_a an = -i * (i - a);
62-
b += 2.0;
63-
50+
const auto b = b_init + 2.0 * i;
6451
D = b + an * D;
65-
if (fabs(D) < EPSILON) {
66-
D = EPSILON;
67-
}
52+
D = (fabs(value_of(D)) >= EPSILON) ? D : std::decay_t<decltype(D)>(EPSILON);
6853
C = b + an / C;
69-
if (fabs(C) < EPSILON) {
70-
C = EPSILON;
71-
}
72-
54+
C = (fabs(value_of(C)) >= EPSILON) ? C : std::decay_t<decltype(C)>(EPSILON);
7355
D = inv(D);
7456
auto delta = C * D;
7557
f *= delta;
76-
77-
const double delta_m1 = value_of(fabs(value_of(delta) - 1.0));
58+
const auto delta_m1 = fabs(value_of(delta) - 1.0);
7859
if (delta_m1 < precision) {
7960
break;
8061
}
@@ -97,40 +78,40 @@ inline return_type_t<T_a, T_z> log_q_gamma_cf(const T_a& a, const T_z& z,
9778
* @tparam T_z type of the value parameter
9879
* @param a shape parameter (must be positive)
9980
* @param z value parameter (must be non-negative)
100-
* @param precision convergence threshold
81+
* @param precision convergence threshold, default of sqrt(machine_epsilon)
10182
* @param max_steps maximum iterations for continued fraction
10283
* @return structure containing log(Q(a,z)) and d/da log(Q(a,z))
10384
*/
10485
template <typename T_a, typename T_z>
105-
inline log_gamma_q_result<return_type_t<T_a, T_z>> log_gamma_q_dgamma(
106-
const T_a& a, const T_z& z, double precision = 1e-16, int max_steps = 250) {
86+
inline auto log_gamma_q_dgamma(
87+
const T_a& a, const T_z& z, double precision = 1.49012e-08, int max_steps = 250) {
10788
using T_return = return_type_t<T_a, T_z>;
108-
const auto a_dbl = value_of(a);
109-
const auto z_dbl = value_of(z);
89+
const auto a_val = value_of(a);
90+
const auto z_val = value_of(z);
11091
// For z > a + 1, use continued fraction for better numerical stability
111-
if (z_dbl > a_dbl + 1.0) {
112-
log_gamma_q_result<T_return> result{internal::log_q_gamma_cf(a_dbl, z_dbl, precision, max_steps), 0.0};
92+
if (z_val > a_val + 1.0) {
93+
std::pair<T_return, T_return> result{internal::log_q_gamma_cf(a_val, z_val, precision, max_steps), 0.0};
11394
// For gradient, use: d/da log(Q) = (1/Q) * dQ/da
11495
// grad_reg_inc_gamma computes dQ/da
115-
const auto Q_val = exp(result.log_q);
96+
const auto Q_val = exp(result.first);
11697
const auto dQ_da
117-
= grad_reg_inc_gamma(a_dbl, z_dbl, tgamma(a_dbl), digamma(a_dbl));
118-
result.dlog_q_da = dQ_da / Q_val;
98+
= grad_reg_inc_gamma(a_val, z_val, tgamma(a_val), digamma(a_val));
99+
result.second = dQ_da / Q_val;
119100
return result;
120101
} else {
121102
// For z <= a + 1, use log1m(P(a,z)) for better numerical accuracy
122-
const auto P_val = gamma_p(a_dbl, z_dbl);
123-
log_gamma_q_result<T_return> result{log1m(P_val), 0.0};
103+
const auto P_val = gamma_p(a_val, z_val);
104+
std::pair<T_return, T_return> result{log1m(P_val), 0.0};
124105
// Gradient: d/da log(Q) = (1/Q) * dQ/da
125106
// grad_reg_inc_gamma computes dQ/da
126-
const auto Q_val = exp(result.log_q);
107+
const auto Q_val = exp(result.first);
127108
if (Q_val > 0) {
128109
const auto dQ_da
129-
= grad_reg_inc_gamma(a_dbl, z_dbl, tgamma(a_dbl), digamma(a_dbl));
130-
result.dlog_q_da = dQ_da / Q_val;
110+
= grad_reg_inc_gamma(a_val, z_val, tgamma(a_val), digamma(a_val));
111+
result.second = dQ_da / Q_val;
131112
} else {
132113
// Fallback if Q rounds to zero - use asymptotic approximation
133-
result.dlog_q_da = log(z_dbl) - digamma(a_dbl);
114+
result.second = log(z_val) - digamma(a_val);
134115
}
135116
return result;
136117
}

stan/math/prim/prob/gamma_lccdf.hpp

Lines changed: 54 additions & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -26,72 +26,70 @@
2626
namespace stan {
2727
namespace math {
2828
namespace internal {
29-
template <typename T>
30-
struct Q_eval {
31-
T log_Q{0.0};
32-
T dlogQ_dalpha{0.0};
33-
};
3429

3530
/**
3631
* Computes log q and d(log q) / d(alpha) using continued fraction.
3732
*/
38-
template <bool any_fvar, bool partials_fvar, typename T_shape, typename T>
39-
static inline auto eval_q_cf(const T& alpha, const T& beta_y) {
40-
Q_eval<return_type_t<T>> out;
33+
template <bool any_fvar, bool partials_fvar, typename T_shape, typename T1, typename T2>
34+
inline auto eval_q_cf(const T1& alpha, const T2& beta_y) {
35+
using scalar_t = return_type_t<T1, T2>;
36+
using ret_t = std::pair<scalar_t, scalar_t>;
4137
if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
4238
auto log_q_result
4339
= log_gamma_q_dgamma(value_of_rec(alpha), value_of_rec(beta_y));
44-
out.log_Q = log_q_result.log_q;
45-
out.dlogQ_dalpha = log_q_result.dlog_q_da;
40+
if (likely(std::isfinite(value_of_rec(log_q_result.first)))) {
41+
return std::optional{log_q_result};
42+
} else {
43+
return std::optional<ret_t>{std::nullopt};
44+
}
4645
} else {
47-
out.log_Q = internal::log_q_gamma_cf(alpha, beta_y);
46+
ret_t out{internal::log_q_gamma_cf(alpha, beta_y), 0.0};
47+
if (unlikely(!std::isfinite(value_of_rec(out.first)))) {
48+
return std::optional<ret_t>{std::nullopt};
49+
}
4850
if constexpr (is_autodiff_v<T_shape>) {
4951
if constexpr (!partials_fvar) {
50-
out.dlogQ_dalpha
52+
out.second
5153
= grad_reg_inc_gamma(alpha, beta_y, tgamma(alpha), digamma(alpha))
52-
/ exp(out.log_Q);
54+
/ exp(out.first);
5355
} else {
54-
T alpha_unit = alpha;
56+
auto alpha_unit = alpha;
5557
alpha_unit.d_ = 1;
56-
T beta_y_unit = beta_y;
58+
auto beta_y_unit = beta_y;
5759
beta_y_unit.d_ = 0;
58-
T log_Q_fvar = internal::log_q_gamma_cf(alpha_unit, beta_y_unit);
59-
out.dlogQ_dalpha = log_Q_fvar.d_;
60+
auto log_Q_fvar = internal::log_q_gamma_cf(alpha_unit, beta_y_unit);
61+
out.second = log_Q_fvar.d_;
6062
}
6163
}
62-
}
63-
64-
if (std::isfinite(value_of_rec(out.log_Q))) {
65-
return std::optional<Q_eval<return_type_t<T>>>{out};
66-
} else {
67-
return std::optional<Q_eval<return_type_t<T>>>{std::nullopt};
64+
return std::optional{out};
6865
}
6966
}
7067

7168
/**
7269
* Computes log q and d(log q) / d(alpha) using log1m.
7370
*/
74-
template <bool partials_fvar, typename T_shape, typename T>
75-
static inline Q_eval<T> eval_q_log1m(const T& alpha, const T& beta_y) {
76-
Q_eval<T> out;
77-
out.log_Q = log1m(gamma_p(alpha, beta_y));
78-
if (!std::isfinite(value_of_rec(out.log_Q))) {
79-
return out;
71+
template <bool partials_fvar, typename T_shape, typename T1, typename T2>
72+
inline auto eval_q_log1m(const T1& alpha, const T2& beta_y) {
73+
using scalar_t = return_type_t<T1, T2>;
74+
using ret_t = std::pair<scalar_t, scalar_t>;
75+
ret_t out{log1m(gamma_p(alpha, beta_y)), 0.0};
76+
if (unlikely(!std::isfinite(value_of_rec(out.first)))) {
77+
return std::optional<ret_t>{std::nullopt};
8078
}
8179
if constexpr (is_autodiff_v<T_shape>) {
8280
if constexpr (partials_fvar) {
83-
T alpha_unit = alpha;
81+
auto alpha_unit = alpha;
8482
alpha_unit.d_ = 1;
85-
T beta_unit = beta_y;
83+
auto beta_unit = beta_y;
8684
beta_unit.d_ = 0;
87-
T log_Q_fvar = log1m(gamma_p(alpha_unit, beta_unit));
88-
out.dlogQ_dalpha = log_Q_fvar.d_;
85+
auto log_Q_fvar = log1m(gamma_p(alpha_unit, beta_unit));
86+
out.second = log_Q_fvar.d_;
8987
} else {
90-
out.dlogQ_dalpha
91-
= -grad_reg_lower_inc_gamma(alpha, beta_y) / exp(out.log_Q);
88+
out.second
89+
= -grad_reg_lower_inc_gamma(alpha, beta_y) / exp(out.first);
9290
}
9391
}
94-
return out;
92+
return std::optional{out};
9593
}
9694
} // namespace internal
9795

@@ -134,58 +132,56 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
134132
for (size_t n = 0; n < N; n++) {
135133
// Explicit results for extreme values
136134
// The gradients are technically ill-defined, but treated as zero
137-
const T_partials_return y_dbl = y_vec.val(n);
138-
if (y_dbl == 0.0) {
135+
const auto y_val = y_vec.val(n);
136+
if (y_val == 0.0) {
139137
continue;
140138
}
141-
if (y_dbl == INFTY) {
139+
if (y_val == INFTY) {
142140
return ops_partials.build(negative_infinity());
143141
}
144142

145-
const T_partials_return alpha_dbl = alpha_vec.val(n);
146-
const T_partials_return beta_dbl = beta_vec.val(n);
143+
const auto alpha_val = alpha_vec.val(n);
144+
const auto beta_val = beta_vec.val(n);
147145

148-
const T_partials_return beta_y = beta_dbl * y_dbl;
146+
const auto beta_y = beta_val * y_val;
149147
if (beta_y == INFTY) {
150148
return ops_partials.build(negative_infinity());
151149
}
152-
153-
const bool use_continued_fraction = beta_y > alpha_dbl + 1.0;
154-
std::optional<internal::Q_eval<T_partials_return>> result;
155-
if (use_continued_fraction) {
156-
result = internal::eval_q_cf<any_fvar, partials_fvar, T_shape>(alpha_dbl, beta_y);
150+
std::optional<std::pair<T_partials_return, T_partials_return>> result;
151+
if (beta_y > alpha_val + 1.0) {
152+
result = internal::eval_q_cf<any_fvar, partials_fvar, T_shape>(alpha_val, beta_y);
157153
} else {
158-
result = internal::eval_q_log1m<partials_fvar, T_shape>(alpha_dbl, beta_y);
154+
result = internal::eval_q_log1m<partials_fvar, T_shape>(alpha_val, beta_y);
159155
if (!result && beta_y > 0.0) {
160156
// Fallback to continued fraction if log1m fails
161-
result = internal::eval_q_cf<any_fvar, partials_fvar, T_shape>(alpha_dbl, beta_y);
157+
result = internal::eval_q_cf<any_fvar, partials_fvar, T_shape>(alpha_val, beta_y);
162158
}
163159
}
164-
if (!result) {
160+
if (unlikely(!result)) {
165161
return ops_partials.build(negative_infinity());
166162
}
167163

168-
P += result->log_Q;
164+
P += result->first;
169165

170166
if constexpr (is_autodiff_v<T_y> || is_autodiff_v<T_inv_scale>) {
171-
const T_partials_return log_y = log(y_dbl);
172-
const T_partials_return alpha_minus_one = fma(alpha_dbl, log_y, -log_y);
167+
const auto log_y = log(y_val);
168+
const auto alpha_minus_one = fma(alpha_val, log_y, -log_y);
173169

174-
const T_partials_return log_pdf = alpha_dbl * log(beta_dbl)
175-
- lgamma(alpha_dbl) + alpha_minus_one
170+
const auto log_pdf = alpha_val * log(beta_val)
171+
- lgamma(alpha_val) + alpha_minus_one
176172
- beta_y;
177173

178-
const T_partials_return hazard = exp(log_pdf - result->log_Q); // f/Q
174+
const auto hazard = exp(log_pdf - result->first); // f/Q
179175

180176
if constexpr (is_autodiff_v<T_y>) {
181177
partials<0>(ops_partials)[n] -= hazard;
182178
}
183179
if constexpr (is_autodiff_v<T_inv_scale>) {
184-
partials<2>(ops_partials)[n] -= (y_dbl / beta_dbl) * hazard;
180+
partials<2>(ops_partials)[n] -= (y_val / beta_val) * hazard;
185181
}
186182
}
187183
if constexpr (is_autodiff_v<T_shape>) {
188-
partials<1>(ops_partials)[n] += result->dlogQ_dalpha;
184+
partials<1>(ops_partials)[n] += result->second;
189185
}
190186
}
191187
return ops_partials.build(P);

0 commit comments

Comments
 (0)