|
26 | 26 | namespace stan { |
27 | 27 | namespace math { |
28 | 28 | namespace internal { |
29 | | -template <typename T> |
30 | | -struct Q_eval { |
31 | | - T log_Q{0.0}; |
32 | | - T dlogQ_dalpha{0.0}; |
33 | | -}; |
34 | 29 |
|
35 | 30 | /** |
36 | 31 | * Computes log q and d(log q) / d(alpha) using continued fraction. |
37 | 32 | */ |
38 | | -template <bool any_fvar, bool partials_fvar, typename T_shape, typename T> |
39 | | -static inline auto eval_q_cf(const T& alpha, const T& beta_y) { |
40 | | - Q_eval<return_type_t<T>> out; |
| 33 | +template <bool any_fvar, bool partials_fvar, typename T_shape, typename T1, typename T2> |
| 34 | +inline auto eval_q_cf(const T1& alpha, const T2& beta_y) { |
| 35 | + using scalar_t = return_type_t<T1, T2>; |
| 36 | + using ret_t = std::pair<scalar_t, scalar_t>; |
41 | 37 | if constexpr (!any_fvar && is_autodiff_v<T_shape>) { |
42 | 38 | auto log_q_result |
43 | 39 | = log_gamma_q_dgamma(value_of_rec(alpha), value_of_rec(beta_y)); |
44 | | - out.log_Q = log_q_result.log_q; |
45 | | - out.dlogQ_dalpha = log_q_result.dlog_q_da; |
| 40 | + if (likely(std::isfinite(value_of_rec(log_q_result.first)))) { |
| 41 | + return std::optional{log_q_result}; |
| 42 | + } else { |
| 43 | + return std::optional<ret_t>{std::nullopt}; |
| 44 | + } |
46 | 45 | } else { |
47 | | - out.log_Q = internal::log_q_gamma_cf(alpha, beta_y); |
| 46 | + ret_t out{internal::log_q_gamma_cf(alpha, beta_y), 0.0}; |
| 47 | + if (unlikely(!std::isfinite(value_of_rec(out.first)))) { |
| 48 | + return std::optional<ret_t>{std::nullopt}; |
| 49 | + } |
48 | 50 | if constexpr (is_autodiff_v<T_shape>) { |
49 | 51 | if constexpr (!partials_fvar) { |
50 | | - out.dlogQ_dalpha |
| 52 | + out.second |
51 | 53 | = grad_reg_inc_gamma(alpha, beta_y, tgamma(alpha), digamma(alpha)) |
52 | | - / exp(out.log_Q); |
| 54 | + / exp(out.first); |
53 | 55 | } else { |
54 | | - T alpha_unit = alpha; |
| 56 | + auto alpha_unit = alpha; |
55 | 57 | alpha_unit.d_ = 1; |
56 | | - T beta_y_unit = beta_y; |
| 58 | + auto beta_y_unit = beta_y; |
57 | 59 | beta_y_unit.d_ = 0; |
58 | | - T log_Q_fvar = internal::log_q_gamma_cf(alpha_unit, beta_y_unit); |
59 | | - out.dlogQ_dalpha = log_Q_fvar.d_; |
| 60 | + auto log_Q_fvar = internal::log_q_gamma_cf(alpha_unit, beta_y_unit); |
| 61 | + out.second = log_Q_fvar.d_; |
60 | 62 | } |
61 | 63 | } |
62 | | - } |
63 | | - |
64 | | - if (std::isfinite(value_of_rec(out.log_Q))) { |
65 | | - return std::optional<Q_eval<return_type_t<T>>>{out}; |
66 | | - } else { |
67 | | - return std::optional<Q_eval<return_type_t<T>>>{std::nullopt}; |
| 64 | + return std::optional{out}; |
68 | 65 | } |
69 | 66 | } |
70 | 67 |
|
71 | 68 | /** |
72 | 69 | * Computes log q and d(log q) / d(alpha) using log1m. |
73 | 70 | */ |
74 | | -template <bool partials_fvar, typename T_shape, typename T> |
75 | | -static inline Q_eval<T> eval_q_log1m(const T& alpha, const T& beta_y) { |
76 | | - Q_eval<T> out; |
77 | | - out.log_Q = log1m(gamma_p(alpha, beta_y)); |
78 | | - if (!std::isfinite(value_of_rec(out.log_Q))) { |
79 | | - return out; |
| 71 | +template <bool partials_fvar, typename T_shape, typename T1, typename T2> |
| 72 | +inline auto eval_q_log1m(const T1& alpha, const T2& beta_y) { |
| 73 | + using scalar_t = return_type_t<T1, T2>; |
| 74 | + using ret_t = std::pair<scalar_t, scalar_t>; |
| 75 | + ret_t out{log1m(gamma_p(alpha, beta_y)), 0.0}; |
| 76 | + if (unlikely(!std::isfinite(value_of_rec(out.first)))) { |
| 77 | + return std::optional<ret_t>{std::nullopt}; |
80 | 78 | } |
81 | 79 | if constexpr (is_autodiff_v<T_shape>) { |
82 | 80 | if constexpr (partials_fvar) { |
83 | | - T alpha_unit = alpha; |
| 81 | + auto alpha_unit = alpha; |
84 | 82 | alpha_unit.d_ = 1; |
85 | | - T beta_unit = beta_y; |
| 83 | + auto beta_unit = beta_y; |
86 | 84 | beta_unit.d_ = 0; |
87 | | - T log_Q_fvar = log1m(gamma_p(alpha_unit, beta_unit)); |
88 | | - out.dlogQ_dalpha = log_Q_fvar.d_; |
| 85 | + auto log_Q_fvar = log1m(gamma_p(alpha_unit, beta_unit)); |
| 86 | + out.second = log_Q_fvar.d_; |
89 | 87 | } else { |
90 | | - out.dlogQ_dalpha |
91 | | - = -grad_reg_lower_inc_gamma(alpha, beta_y) / exp(out.log_Q); |
| 88 | + out.second |
| 89 | + = -grad_reg_lower_inc_gamma(alpha, beta_y) / exp(out.first); |
92 | 90 | } |
93 | 91 | } |
94 | | - return out; |
| 92 | + return std::optional{out}; |
95 | 93 | } |
96 | 94 | } // namespace internal |
97 | 95 |
|
@@ -134,58 +132,56 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf( |
134 | 132 | for (size_t n = 0; n < N; n++) { |
135 | 133 | // Explicit results for extreme values |
136 | 134 | // The gradients are technically ill-defined, but treated as zero |
137 | | - const T_partials_return y_dbl = y_vec.val(n); |
138 | | - if (y_dbl == 0.0) { |
| 135 | + const auto y_val = y_vec.val(n); |
| 136 | + if (y_val == 0.0) { |
139 | 137 | continue; |
140 | 138 | } |
141 | | - if (y_dbl == INFTY) { |
| 139 | + if (y_val == INFTY) { |
142 | 140 | return ops_partials.build(negative_infinity()); |
143 | 141 | } |
144 | 142 |
|
145 | | - const T_partials_return alpha_dbl = alpha_vec.val(n); |
146 | | - const T_partials_return beta_dbl = beta_vec.val(n); |
| 143 | + const auto alpha_val = alpha_vec.val(n); |
| 144 | + const auto beta_val = beta_vec.val(n); |
147 | 145 |
|
148 | | - const T_partials_return beta_y = beta_dbl * y_dbl; |
| 146 | + const auto beta_y = beta_val * y_val; |
149 | 147 | if (beta_y == INFTY) { |
150 | 148 | return ops_partials.build(negative_infinity()); |
151 | 149 | } |
152 | | - |
153 | | - const bool use_continued_fraction = beta_y > alpha_dbl + 1.0; |
154 | | - std::optional<internal::Q_eval<T_partials_return>> result; |
155 | | - if (use_continued_fraction) { |
156 | | - result = internal::eval_q_cf<any_fvar, partials_fvar, T_shape>(alpha_dbl, beta_y); |
| 150 | + std::optional<std::pair<T_partials_return, T_partials_return>> result; |
| 151 | + if (beta_y > alpha_val + 1.0) { |
| 152 | + result = internal::eval_q_cf<any_fvar, partials_fvar, T_shape>(alpha_val, beta_y); |
157 | 153 | } else { |
158 | | - result = internal::eval_q_log1m<partials_fvar, T_shape>(alpha_dbl, beta_y); |
| 154 | + result = internal::eval_q_log1m<partials_fvar, T_shape>(alpha_val, beta_y); |
159 | 155 | if (!result && beta_y > 0.0) { |
160 | 156 | // Fallback to continued fraction if log1m fails |
161 | | - result = internal::eval_q_cf<any_fvar, partials_fvar, T_shape>(alpha_dbl, beta_y); |
| 157 | + result = internal::eval_q_cf<any_fvar, partials_fvar, T_shape>(alpha_val, beta_y); |
162 | 158 | } |
163 | 159 | } |
164 | | - if (!result) { |
| 160 | + if (unlikely(!result)) { |
165 | 161 | return ops_partials.build(negative_infinity()); |
166 | 162 | } |
167 | 163 |
|
168 | | - P += result->log_Q; |
| 164 | + P += result->first; |
169 | 165 |
|
170 | 166 | if constexpr (is_autodiff_v<T_y> || is_autodiff_v<T_inv_scale>) { |
171 | | - const T_partials_return log_y = log(y_dbl); |
172 | | - const T_partials_return alpha_minus_one = fma(alpha_dbl, log_y, -log_y); |
| 167 | + const auto log_y = log(y_val); |
| 168 | + const auto alpha_minus_one = fma(alpha_val, log_y, -log_y); |
173 | 169 |
|
174 | | - const T_partials_return log_pdf = alpha_dbl * log(beta_dbl) |
175 | | - - lgamma(alpha_dbl) + alpha_minus_one |
| 170 | + const auto log_pdf = alpha_val * log(beta_val) |
| 171 | + - lgamma(alpha_val) + alpha_minus_one |
176 | 172 | - beta_y; |
177 | 173 |
|
178 | | - const T_partials_return hazard = exp(log_pdf - result->log_Q); // f/Q |
| 174 | + const auto hazard = exp(log_pdf - result->first); // f/Q |
179 | 175 |
|
180 | 176 | if constexpr (is_autodiff_v<T_y>) { |
181 | 177 | partials<0>(ops_partials)[n] -= hazard; |
182 | 178 | } |
183 | 179 | if constexpr (is_autodiff_v<T_inv_scale>) { |
184 | | - partials<2>(ops_partials)[n] -= (y_dbl / beta_dbl) * hazard; |
| 180 | + partials<2>(ops_partials)[n] -= (y_val / beta_val) * hazard; |
185 | 181 | } |
186 | 182 | } |
187 | 183 | if constexpr (is_autodiff_v<T_shape>) { |
188 | | - partials<1>(ops_partials)[n] += result->dlogQ_dalpha; |
| 184 | + partials<1>(ops_partials)[n] += result->second; |
189 | 185 | } |
190 | 186 | } |
191 | 187 | return ops_partials.build(P); |
|
0 commit comments