Skip to content

Commit 7724011

Browse files
SteveBrondersyclik
authored andcommitted
update to fix value_of and value_of_rec in gamma functions
1 parent 6e1dc0a commit 7724011

2 files changed

Lines changed: 27 additions & 13 deletions

File tree

stan/math/prim/fun/log_gamma_q_dgamma.hpp

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,16 @@
1515
#include <stan/math/prim/fun/log1m.hpp>
1616
#include <stan/math/prim/fun/tgamma.hpp>
1717
#include <stan/math/prim/fun/value_of.hpp>
18+
#include <stan/math/prim/fun/value_of_rec.hpp>
1819
#include <cmath>
1920

2021
namespace stan {
2122
namespace math {
2223

2324
namespace internal {
2425

26+
constexpr double LOG_Q_GAMMA_CF_PRECISION = 1.49012e-12;
27+
2528
/**
2629
* Compute log(Q(a,z)) using continued fraction expansion for upper incomplete
2730
* gamma function.
@@ -36,26 +39,33 @@ namespace internal {
3639
*/
3740
template <typename T_a, typename T_z>
3841
inline return_type_t<T_a, T_z> log_q_gamma_cf(const T_a& a, const T_z& z,
39-
double precision = 1.49012e-08,
42+
double precision
43+
= LOG_Q_GAMMA_CF_PRECISION,
4044
int max_steps = 250) {
4145
using T_return = return_type_t<T_a, T_z>;
4246
const T_return log_prefactor = a * log(z) - z - lgamma(a);
4347

4448
T_return b_init = z + 1.0 - a;
45-
T_return C = (fabs(value_of(b_init)) >= EPSILON) ? b_init : std::decay_t<decltype(b_init)>(EPSILON);
49+
T_return C = (fabs(value_of_rec(b_init)) >= EPSILON)
50+
? b_init
51+
: std::decay_t<decltype(b_init)>(EPSILON);
4652
T_return D = 0.0;
4753
T_return f = C;
4854
for (int i = 1; i <= max_steps; ++i) {
4955
T_a an = -i * (i - a);
5056
const T_return b = b_init + 2.0 * i;
5157
D = b + an * D;
52-
D = (fabs(value_of(D)) >= EPSILON) ? D : std::decay_t<decltype(D)>(EPSILON);
58+
D = (fabs(value_of_rec(D)) >= EPSILON)
59+
? D
60+
: std::decay_t<decltype(D)>(EPSILON);
5361
C = b + an / C;
54-
C = (fabs(value_of(C)) >= EPSILON) ? C : std::decay_t<decltype(C)>(EPSILON);
62+
C = (fabs(value_of_rec(C)) >= EPSILON)
63+
? C
64+
: std::decay_t<decltype(C)>(EPSILON);
5565
D = inv(D);
5666
const T_return delta = C * D;
5767
f *= delta;
58-
const double delta_m1 = fabs(value_of(delta) - 1.0);
68+
const double delta_m1 = fabs(value_of_rec(delta) - 1.0);
5969
if (delta_m1 < precision) {
6070
break;
6171
}
@@ -84,13 +94,16 @@ inline return_type_t<T_a, T_z> log_q_gamma_cf(const T_a& a, const T_z& z,
8494
*/
8595
template <typename T_a, typename T_z>
8696
inline std::pair<return_type_t<T_a, T_z>, return_type_t<T_a, T_z>> log_gamma_q_dgamma(
87-
const T_a& a, const T_z& z, double precision = 1.49012e-08, int max_steps = 250) {
97+
const T_a& a, const T_z& z,
98+
double precision = internal::LOG_Q_GAMMA_CF_PRECISION,
99+
int max_steps = 250) {
88100
using T_return = return_type_t<T_a, T_z>;
89101
const double a_val = value_of(a);
90102
const double z_val = value_of(z);
91103
// For z > a + 1, use continued fraction for better numerical stability
92104
if (z_val > a_val + 1.0) {
93-
std::pair<T_return, T_return> result{internal::log_q_gamma_cf(a_val, z_val, precision, max_steps), 0.0};
105+
std::pair<T_return, T_return> result{
106+
internal::log_q_gamma_cf(a_val, z_val, precision, max_steps), 0.0};
94107
// For gradient, use: d/da log(Q) = (1/Q) * dQ/da
95108
// grad_reg_inc_gamma computes dQ/da
96109
const T_return Q_val = exp(result.first);

stan/math/prim/prob/gamma_lccdf.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef STAN_MATH_PRIM_PROB_GAMMA_LCCDF_HPP
22
#define STAN_MATH_PRIM_PROB_GAMMA_LCCDF_HPP
33

4+
#include <stan/math/fwd/meta/is_fvar.hpp>
45
#include <stan/math/prim/meta.hpp>
56
#include <stan/math/prim/err.hpp>
67
#include <stan/math/prim/fun/constants.hpp>
@@ -38,8 +39,8 @@ eval_q_cf(const T1& alpha, const T2& beta_y) {
3839
using ret_t = std::pair<scalar_t, scalar_t>;
3940
if constexpr (!any_fvar && is_autodiff_v<T_shape>) {
4041
std::pair<double, double> log_q_result
41-
= log_gamma_q_dgamma(value_of_rec(alpha), value_of_rec(beta_y));
42-
if (likely(std::isfinite(value_of_rec(log_q_result.first)))) {
42+
= log_gamma_q_dgamma(value_of(alpha), value_of(beta_y));
43+
if (likely(std::isfinite(log_q_result.first))) {
4344
return std::optional{log_q_result};
4445
} else {
4546
return std::optional<ret_t>{std::nullopt};
@@ -127,10 +128,10 @@ inline return_type_t<T_y, T_shape, T_inv_scale> gamma_lccdf(
127128
scalar_seq_view<T_beta_ref> beta_vec(beta_ref);
128129
const size_t N = max_size(y, alpha, beta);
129130

130-
constexpr bool any_fvar = is_fvar<scalar_type_t<T_y>>::value
131-
|| is_fvar<scalar_type_t<T_shape>>::value
132-
|| is_fvar<scalar_type_t<T_inv_scale>>::value;
133-
constexpr bool partials_fvar = is_fvar<T_partials_return>::value;
131+
constexpr bool any_fvar = is_fvar_v<scalar_type_t<T_y>>
132+
|| is_fvar_v<scalar_type_t<T_shape>>
133+
|| is_fvar_v<scalar_type_t<T_inv_scale>>;
134+
constexpr bool partials_fvar = is_fvar_v<T_partials_return>;
134135

135136
for (size_t n = 0; n < N; n++) {
136137
// Explicit results for extreme values

0 commit comments

Comments
 (0)