Skip to content

Commit 7d5ffd7

Browse files
authored
Merge pull request #504 from utkarshpawade/fix/kde-correction-empty-vector
Validate user-provided pit at the entry point
2 parents d1e452e + 29820db commit 7d5ffd7

3 files changed

Lines changed: 53 additions & 21 deletions

File tree

NEWS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# bayesplot (development version)
22

3-
3+
* Validate user-provided `pit` values in `ppc_loo_pit_data()` and `ppc_loo_pit_qq()`, rejecting non-numeric inputs, missing values, and values outside `[0, 1]`.
44
* New `show_marginal` argument to `ppd_*()` functions to show the PPD (the marginal predictive distribution), by @mattansb (#425)
55
* `ppc_ecdf_overlay()`, `ppc_ecdf_overlay_grouped()`, and `ppd_ecdf_overlay()` now always use `geom_step()`. The `discrete` argument is deprecated.
66
* Fixed missing `drop = FALSE` in `nuts_params.CmdStanMCMC()`.

R/ppc-loo.R

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,10 @@ ppc_loo_pit_data <-
302302
boundary_correction = TRUE,
303303
grid_len = 512) {
304304
if (!is.null(pit)) {
305-
stopifnot(is.numeric(pit), is_vector_or_1Darray(pit))
305+
pit <- validate_pit(pit)
306+
if (boundary_correction && length(pit) < 2L) {
307+
abort("At least 2 PIT values are required when 'boundary_correction' is TRUE.")
308+
}
306309
inform("'pit' specified so ignoring 'y','yrep','lw' if specified.")
307310
} else {
308311
suggested_package("rstantools")
@@ -348,7 +351,7 @@ ppc_loo_pit_qq <- function(y,
348351

349352
compare <- match.arg(compare)
350353
if (!is.null(pit)) {
351-
stopifnot(is.numeric(pit), is_vector_or_1Darray(pit))
354+
pit <- validate_pit(pit)
352355
inform("'pit' specified so ignoring 'y','yrep','lw' if specified.")
353356
} else {
354357
suggested_package("rstantools")
@@ -795,14 +798,6 @@ ppc_loo_ribbon <-
795798
# Generate boundary corrected values via a linear convolution using a
796799
# 1-D Gaussian window filter. This method uses the "reflection method"
797800
# to estimate these pvalues and helps speed up the code
798-
if (any(is.infinite(x))) {
799-
warn(paste(
800-
"Ignored", sum(is.infinite(x)),
801-
"Non-finite PIT values are invalid for KDE boundary correction method"
802-
))
803-
x <- x[is.finite(x)]
804-
}
805-
806801
if (grid_len < 100) {
807802
grid_len <- 100
808803
}
@@ -819,6 +814,10 @@ ppc_loo_ribbon <-
819814
# 1-D Convolution
820815
bc_pvals <- .linear_convolution(x, bw, grid_counts, grid_breaks, grid_len)
821816

817+
if (all(is.na(bc_pvals))) {
818+
abort("KDE boundary correction produced all NA values.")
819+
}
820+
822821
# Generate vector of x-axis values for plotting based on binned relative freqs
823822
n_breaks <- length(grid_breaks)
824823

tests/testthat/test-ppc-loo.R

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,17 +59,50 @@ test_that("ppc_loo_pit_overlay works with boundary_correction=FALSE", {
5959
expect_gg(p1)
6060
})
6161

62-
test_that(".kde_correction warns when PIT values are non-finite", {
63-
set.seed(123)
64-
pit_vals <- c(stats::runif(500), Inf)
65-
expect_warning(
66-
out <- .kde_correction(pit_vals, bw = "nrd0", grid_len = 128),
67-
"Non-finite PIT values are invalid"
62+
test_that("ppc_loo_pit_data validates user-provided pit values", {
63+
expect_error(
64+
ppc_loo_pit_data(pit = c(0.5, Inf)),
65+
"between 0 and 1"
66+
)
67+
expect_error(
68+
ppc_loo_pit_data(pit = c(-1, 0.5)),
69+
"between 0 and 1"
70+
)
71+
expect_error(
72+
ppc_loo_pit_data(pit = c(0.5, NA)),
73+
"NAs not allowed"
74+
)
75+
expect_error(
76+
ppc_loo_pit_data(pit = "not numeric"),
77+
"is.numeric"
78+
)
79+
expect_error(
80+
ppc_loo_pit_data(pit = c(Inf, -Inf, Inf)),
81+
"between 0 and 1"
82+
)
83+
expect_error(
84+
ppc_loo_pit_data(pit = 0.5, boundary_correction = TRUE),
85+
"At least 2 PIT values"
86+
)
87+
})
88+
89+
test_that("ppc_loo_pit_qq validates user-provided pit values", {
90+
expect_error(
91+
ppc_loo_pit_qq(pit = c(0.5, Inf)),
92+
"between 0 and 1"
93+
)
94+
expect_error(
95+
ppc_loo_pit_qq(pit = c(-1, 0.5)),
96+
"between 0 and 1"
97+
)
98+
expect_error(
99+
ppc_loo_pit_qq(pit = c(0.5, NA)),
100+
"NAs not allowed"
101+
)
102+
expect_error(
103+
ppc_loo_pit_qq(pit = "not numeric"),
104+
"is.numeric"
68105
)
69-
expect_type(out, "list")
70-
expect_true(all(c("xs", "bc_pvals") %in% names(out)))
71-
expect_equal(length(out$xs), 128)
72-
expect_equal(length(out$bc_pvals), 128)
73106
})
74107

75108
test_that("ppc_loo_pit_qq returns ggplot object", {

0 commit comments

Comments
 (0)