2222#include < stan/math/prim/fun/log_gamma_q_dgamma.hpp>
2323#include < stan/math/prim/functor/partials_propagator.hpp>
2424#include < cmath>
25-
25+ # include < optional >
2626namespace stan {
2727namespace math {
2828namespace internal {
2929template <typename T>
3030struct Q_eval {
3131 T log_Q{0.0 };
3232 T dlogQ_dalpha{0.0 };
33- bool ok{false };
3433};
3534
3635/* *
3736 * Computes log q and d(log q) / d(alpha) using continued fraction.
3837 */
39- template <typename T, typename T_shape, bool any_fvar, bool partials_fvar >
40- static inline Q_eval<T> eval_q_cf (const T& alpha, const T& beta_y) {
41- Q_eval<T > out;
38+ template <bool any_fvar, bool partials_fvar, typename T_shape, typename T >
39+ static inline auto eval_q_cf (const T& alpha, const T& beta_y) {
40+ Q_eval<return_type_t <T> > out;
4241 if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
4342 auto log_q_result
4443 = log_gamma_q_dgamma (value_of_rec (alpha), value_of_rec (beta_y));
@@ -62,23 +61,23 @@ static inline Q_eval<T> eval_q_cf(const T& alpha, const T& beta_y) {
6261 }
6362 }
6463
65- out.ok = std::isfinite (value_of_rec (out.log_Q ));
66- return out;
64+ if (std::isfinite (value_of_rec (out.log_Q ))) {
65+ return std::optional<Q_eval<return_type_t <T>>>{out};
66+ } else {
67+ return std::optional<Q_eval<return_type_t <T>>>{std::nullopt };
68+ }
6769}
6870
6971/* *
7072 * Computes log q and d(log q) / d(alpha) using log1m.
7173 */
72- template <typename T , typename T_shape, bool partials_fvar >
74+ template <bool partials_fvar , typename T_shape, typename T >
7375static inline Q_eval<T> eval_q_log1m (const T& alpha, const T& beta_y) {
7476 Q_eval<T> out;
7577 out.log_Q = log1m (gamma_p (alpha, beta_y));
76-
7778 if (!std::isfinite (value_of_rec (out.log_Q ))) {
78- out.ok = false ;
7979 return out;
8080 }
81-
8281 if constexpr (is_autodiff_v<T_shape>) {
8382 if constexpr (partials_fvar) {
8483 T alpha_unit = alpha;
@@ -92,8 +91,6 @@ static inline Q_eval<T> eval_q_log1m(const T& alpha, const T& beta_y) {
9291 = -grad_reg_lower_inc_gamma (alpha, beta_y) / exp (out.log_Q );
9392 }
9493 }
95-
96- out.ok = true ;
9794 return out;
9895}
9996} // namespace internal
@@ -154,26 +151,21 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
154151 }
155152
156153 const bool use_continued_fraction = beta_y > alpha_dbl + 1.0 ;
157- internal::Q_eval<T_partials_return> result;
154+ std::optional< internal::Q_eval<T_partials_return> > result;
158155 if (use_continued_fraction) {
159- result = internal::eval_q_cf<T_partials_return, T_shape, any_fvar,
160- partials_fvar>(alpha_dbl, beta_y);
156+ result = internal::eval_q_cf<any_fvar, partials_fvar, T_shape>(alpha_dbl, beta_y);
161157 } else {
162- result
163- = internal::eval_q_log1m<T_partials_return, T_shape, partials_fvar>(
164- alpha_dbl, beta_y);
165-
166- if (!result.ok && beta_y > 0.0 ) {
158+ result = internal::eval_q_log1m<partials_fvar, T_shape>(alpha_dbl, beta_y);
159+ if (!result && beta_y > 0.0 ) {
167160 // Fallback to continued fraction if log1m fails
168- result = internal::eval_q_cf<T_partials_return, T_shape, any_fvar,
169- partials_fvar>(alpha_dbl, beta_y);
161+ result = internal::eval_q_cf<any_fvar, partials_fvar, T_shape>(alpha_dbl, beta_y);
170162 }
171163 }
172- if (!result. ok ) {
164+ if (!result) {
173165 return ops_partials.build (negative_infinity ());
174166 }
175167
176- P += result. log_Q ;
168+ P += result-> log_Q ;
177169
178170 if constexpr (is_autodiff_v<T_y> || is_autodiff_v<T_inv_scale>) {
179171 const T_partials_return log_y = log (y_dbl);
@@ -183,7 +175,7 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
183175 - lgamma (alpha_dbl) + alpha_minus_one
184176 - beta_y;
185177
186- const T_partials_return hazard = exp (log_pdf - result. log_Q ); // f/Q
178+ const T_partials_return hazard = exp (log_pdf - result-> log_Q ); // f/Q
187179
188180 if constexpr (is_autodiff_v<T_y>) {
189181 partials<0 >(ops_partials)[n] -= hazard;
@@ -193,7 +185,7 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
193185 }
194186 }
195187 if constexpr (is_autodiff_v<T_shape>) {
196- partials<1 >(ops_partials)[n] += result. dlogQ_dalpha ;
188+ partials<1 >(ops_partials)[n] += result-> dlogQ_dalpha ;
197189 }
198190 }
199191 return ops_partials.build (P);
0 commit comments