Skip to content

Commit c150da9

Browse files
SteveBrondersyclik
authored andcommitted
cleanup for gamma_lccdf
1 parent 86ac561 commit c150da9

2 files changed

Lines changed: 37 additions & 59 deletions

File tree

stan/math/prim/fun/log_gamma_q_dgamma.hpp

Lines changed: 18 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <stan/math/prim/fun/constants.hpp>
66
#include <stan/math/prim/fun/digamma.hpp>
77
#include <stan/math/prim/fun/exp.hpp>
8+
#include <stan/math/prim/fun/fabs.hpp>
89
#include <stan/math/prim/fun/gamma_p.hpp>
910
#include <stan/math/prim/fun/gamma_q.hpp>
1011
#include <stan/math/prim/fun/grad_reg_inc_gamma.hpp>
@@ -48,18 +49,13 @@ template <typename T_a, typename T_z>
4849
inline return_type_t<T_a, T_z> log_q_gamma_cf(const T_a& a, const T_z& z,
4950
double precision = 1e-16,
5051
int max_steps = 250) {
51-
using stan::math::lgamma;
52-
using stan::math::log;
53-
using stan::math::value_of;
54-
using std::fabs;
5552
using T_return = return_type_t<T_a, T_z>;
53+
const auto log_prefactor = a * log(z) - z - lgamma(a);
5654

57-
const T_return log_prefactor = a * log(z) - z - lgamma(a);
58-
59-
T_return b = z + 1.0 - a;
60-
T_return C = (fabs(value_of(b)) >= EPSILON) ? b : T_return(EPSILON);
61-
T_return D = 0.0;
62-
T_return f = C;
55+
auto b = z + 1.0 - a;
56+
auto C = (fabs(value_of(b)) >= EPSILON) ? b : std::decay_t<decltype(b)>(EPSILON);
57+
auto D = 0.0;
58+
auto f = C;
6359

6460
for (int i = 1; i <= max_steps; ++i) {
6561
T_a an = -i * (i - a);
@@ -75,15 +71,14 @@ inline return_type_t<T_a, T_z> log_q_gamma_cf(const T_a& a, const T_z& z,
7571
}
7672

7773
D = inv(D);
78-
T_return delta = C * D;
74+
auto delta = C * D;
7975
f *= delta;
8076

8177
const double delta_m1 = value_of(fabs(value_of(delta) - 1.0));
8278
if (delta_m1 < precision) {
8379
break;
8480
}
8581
}
86-
8782
return log_prefactor - log(f);
8883
}
8984

@@ -109,45 +104,36 @@ inline return_type_t<T_a, T_z> log_q_gamma_cf(const T_a& a, const T_z& z,
109104
template <typename T_a, typename T_z>
110105
inline log_gamma_q_result<return_type_t<T_a, T_z>> log_gamma_q_dgamma(
111106
const T_a& a, const T_z& z, double precision = 1e-16, int max_steps = 250) {
112-
using std::exp;
113-
using std::log;
114107
using T_return = return_type_t<T_a, T_z>;
115-
116-
const double a_dbl = value_of(a);
117-
const double z_dbl = value_of(z);
118-
119-
log_gamma_q_result<T_return> result;
120-
108+
const auto a_dbl = value_of(a);
109+
const auto z_dbl = value_of(z);
121110
// For z > a + 1, use continued fraction for better numerical stability
122111
if (z_dbl > a_dbl + 1.0) {
123-
result.log_q = internal::log_q_gamma_cf(a_dbl, z_dbl, precision, max_steps);
124-
112+
log_gamma_q_result<T_return> result{internal::log_q_gamma_cf(a_dbl, z_dbl, precision, max_steps), 0.0};
125113
// For gradient, use: d/da log(Q) = (1/Q) * dQ/da
126114
// grad_reg_inc_gamma computes dQ/da
127-
const double Q_val = exp(result.log_q);
128-
const double dQ_da
115+
const auto Q_val = exp(result.log_q);
116+
const auto dQ_da
129117
= grad_reg_inc_gamma(a_dbl, z_dbl, tgamma(a_dbl), digamma(a_dbl));
130118
result.dlog_q_da = dQ_da / Q_val;
131-
119+
return result;
132120
} else {
133121
// For z <= a + 1, use log1m(P(a,z)) for better numerical accuracy
134-
const double P_val = gamma_p(a_dbl, z_dbl);
135-
result.log_q = log1m(P_val);
136-
122+
const auto P_val = gamma_p(a_dbl, z_dbl);
123+
log_gamma_q_result<T_return> result{log1m(P_val), 0.0};
137124
// Gradient: d/da log(Q) = (1/Q) * dQ/da
138125
// grad_reg_inc_gamma computes dQ/da
139-
const double Q_val = exp(result.log_q);
126+
const auto Q_val = exp(result.log_q);
140127
if (Q_val > 0) {
141-
const double dQ_da
128+
const auto dQ_da
142129
= grad_reg_inc_gamma(a_dbl, z_dbl, tgamma(a_dbl), digamma(a_dbl));
143130
result.dlog_q_da = dQ_da / Q_val;
144131
} else {
145132
// Fallback if Q rounds to zero - use asymptotic approximation
146133
result.dlog_q_da = log(z_dbl) - digamma(a_dbl);
147134
}
135+
return result;
148136
}
149-
150-
return result;
151137
}
152138

153139
} // namespace math

stan/math/prim/prob/gamma_lccdf.hpp

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,22 @@
2222
#include <stan/math/prim/fun/log_gamma_q_dgamma.hpp>
2323
#include <stan/math/prim/functor/partials_propagator.hpp>
2424
#include <cmath>
25-
25+
#include <optional>
2626
namespace stan {
2727
namespace math {
2828
namespace internal {
2929
template <typename T>
3030
struct Q_eval {
3131
T log_Q{0.0};
3232
T dlogQ_dalpha{0.0};
33-
bool ok{false};
3433
};
3534

3635
/**
3736
* Computes log q and d(log q) / d(alpha) using continued fraction.
3837
*/
39-
template <typename T, typename T_shape, bool any_fvar, bool partials_fvar>
40-
static inline Q_eval<T> eval_q_cf(const T& alpha, const T& beta_y) {
41-
Q_eval<T> out;
38+
template <bool any_fvar, bool partials_fvar, typename T_shape, typename T>
39+
static inline auto eval_q_cf(const T& alpha, const T& beta_y) {
40+
Q_eval<return_type_t<T>> out;
4241
if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
4342
auto log_q_result
4443
= log_gamma_q_dgamma(value_of_rec(alpha), value_of_rec(beta_y));
@@ -62,23 +61,23 @@ static inline Q_eval<T> eval_q_cf(const T& alpha, const T& beta_y) {
6261
}
6362
}
6463

65-
out.ok = std::isfinite(value_of_rec(out.log_Q));
66-
return out;
64+
if (std::isfinite(value_of_rec(out.log_Q))) {
65+
return std::optional<Q_eval<return_type_t<T>>>{out};
66+
} else {
67+
return std::optional<Q_eval<return_type_t<T>>>{std::nullopt};
68+
}
6769
}
6870

6971
/**
7072
* Computes log q and d(log q) / d(alpha) using log1m.
7173
*/
72-
template <typename T, typename T_shape, bool partials_fvar>
74+
template <bool partials_fvar, typename T_shape, typename T>
7375
static inline Q_eval<T> eval_q_log1m(const T& alpha, const T& beta_y) {
7476
Q_eval<T> out;
7577
out.log_Q = log1m(gamma_p(alpha, beta_y));
76-
7778
if (!std::isfinite(value_of_rec(out.log_Q))) {
78-
out.ok = false;
7979
return out;
8080
}
81-
8281
if constexpr (is_autodiff_v<T_shape>) {
8382
if constexpr (partials_fvar) {
8483
T alpha_unit = alpha;
@@ -92,8 +91,6 @@ static inline Q_eval<T> eval_q_log1m(const T& alpha, const T& beta_y) {
9291
= -grad_reg_lower_inc_gamma(alpha, beta_y) / exp(out.log_Q);
9392
}
9493
}
95-
96-
out.ok = true;
9794
return out;
9895
}
9996
} // namespace internal
@@ -154,26 +151,21 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
154151
}
155152

156153
const bool use_continued_fraction = beta_y > alpha_dbl + 1.0;
157-
internal::Q_eval<T_partials_return> result;
154+
std::optional<internal::Q_eval<T_partials_return>> result;
158155
if (use_continued_fraction) {
159-
result = internal::eval_q_cf<T_partials_return, T_shape, any_fvar,
160-
partials_fvar>(alpha_dbl, beta_y);
156+
result = internal::eval_q_cf<any_fvar, partials_fvar, T_shape>(alpha_dbl, beta_y);
161157
} else {
162-
result
163-
= internal::eval_q_log1m<T_partials_return, T_shape, partials_fvar>(
164-
alpha_dbl, beta_y);
165-
166-
if (!result.ok && beta_y > 0.0) {
158+
result = internal::eval_q_log1m<partials_fvar, T_shape>(alpha_dbl, beta_y);
159+
if (!result && beta_y > 0.0) {
167160
// Fallback to continued fraction if log1m fails
168-
result = internal::eval_q_cf<T_partials_return, T_shape, any_fvar,
169-
partials_fvar>(alpha_dbl, beta_y);
161+
result = internal::eval_q_cf<any_fvar, partials_fvar, T_shape>(alpha_dbl, beta_y);
170162
}
171163
}
172-
if (!result.ok) {
164+
if (!result) {
173165
return ops_partials.build(negative_infinity());
174166
}
175167

176-
P += result.log_Q;
168+
P += result->log_Q;
177169

178170
if constexpr (is_autodiff_v<T_y> || is_autodiff_v<T_inv_scale>) {
179171
const T_partials_return log_y = log(y_dbl);
@@ -183,7 +175,7 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
183175
- lgamma(alpha_dbl) + alpha_minus_one
184176
- beta_y;
185177

186-
const T_partials_return hazard = exp(log_pdf - result.log_Q); // f/Q
178+
const T_partials_return hazard = exp(log_pdf - result->log_Q); // f/Q
187179

188180
if constexpr (is_autodiff_v<T_y>) {
189181
partials<0>(ops_partials)[n] -= hazard;
@@ -193,7 +185,7 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
193185
}
194186
}
195187
if constexpr (is_autodiff_v<T_shape>) {
196-
partials<1>(ops_partials)[n] += result.dlogQ_dalpha;
188+
partials<1>(ops_partials)[n] += result->dlogQ_dalpha;
197189
}
198190
}
199191
return ops_partials.build(P);

0 commit comments

Comments
 (0)