Skip to content

Commit cbde4f0

Browse files
Accept posterior::draws objects in PPC/PPD functions (#542)
validate_predictions() now coerces posterior::draws objects (draws_matrix, draws_array, draws_df, draws_rvars) to a plain matrix before validation, matching the pattern already used by prepare_mcmc_array() on the MCMC side. This lets users pass draws objects directly to ppc_* and ppd_* functions without manually calling as.matrix() first.
1 parent 44c3e5c commit cbde4f0

2 files changed

Lines changed: 57 additions & 3 deletions

File tree

R/helpers-ppc.R

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,20 @@ validate_y <- function(y) {
5858
#' Validate predictions (`yrep` or `ypred`)
5959
#'
6060
#' Checks that `predictions` is a numeric matrix, doesn't have any NAs, and has
61-
#' the correct number of columns.
61+
#' the correct number of columns. If `predictions` is a `posterior::draws`
62+
#' object it is first coerced to a matrix.
6263
#'
63-
#' @param predictions The user's `yrep` or `ypred` object (SxN matrix).
64+
#' @param predictions The user's `yrep` or `ypred` object (SxN matrix or a
65+
#' `posterior::draws` object).
6466
#' @param `n_obs` The number of observations (columns) that `predictions` should
6567
#' have, if applicable.
6668
#' @return Either throws an error or returns a numeric matrix.
6769
#' @noRd
6870
validate_predictions <- function(predictions, n_obs = NULL) {
69-
# sanity checks
71+
if (posterior::is_draws(predictions)) {
72+
predictions <- posterior::as_draws_matrix(predictions)
73+
predictions <- unclass(predictions)
74+
}
7075
stopifnot(is.matrix(predictions), is.numeric(predictions))
7176
if (!is.null(n_obs)) {
7277
stopifnot(length(n_obs) == 1, n_obs == as.integer(n_obs))

tests/testthat/test-helpers-ppc.R

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
library(posterior)
12
source(test_path("data-for-ppc-tests.R"))
23
source(test_path("data-for-mcmc-tests.R"))
34

@@ -113,6 +114,54 @@ test_that("get_interpolation_values catches impossible values", {
113114
)
114115
})
115116

117+
# validate_predictions with posterior::draws objects ----------------------
118+
test_that("validate_predictions accepts draws_matrix", {
119+
dm <- posterior::as_draws_matrix(yrep)
120+
result <- validate_predictions(dm, ncol(yrep))
121+
expect_true(is.matrix(result))
122+
expect_equal(dim(result), dim(yrep))
123+
expect_true(is.numeric(result))
124+
})
125+
126+
test_that("validate_predictions accepts draws_array", {
127+
da <- posterior::as_draws_array(yrep)
128+
result <- validate_predictions(da)
129+
expect_true(is.matrix(result))
130+
expect_true(is.numeric(result))
131+
})
132+
133+
test_that("validate_predictions accepts draws_df", {
134+
ddf <- posterior::as_draws_df(yrep)
135+
result <- validate_predictions(ddf)
136+
expect_true(is.matrix(result))
137+
expect_true(is.numeric(result))
138+
})
139+
140+
test_that("validate_predictions accepts draws_rvars", {
141+
dr <- posterior::as_draws_rvars(yrep)
142+
result <- validate_predictions(dr)
143+
expect_true(is.matrix(result))
144+
expect_true(is.numeric(result))
145+
})
146+
147+
test_that("ppc_dens_overlay works with draws_matrix input", {
148+
dm <- posterior::as_draws_matrix(yrep)
149+
p <- ppc_dens_overlay(y, dm)
150+
expect_s3_class(p, "ggplot")
151+
})
152+
153+
test_that("ppc_hist works with draws_matrix input", {
154+
dm <- posterior::as_draws_matrix(yrep)
155+
p <- ppc_hist(y, dm[1:8, ])
156+
expect_s3_class(p, "ggplot")
157+
})
158+
159+
test_that("ppc_stat works with draws_matrix input", {
160+
dm <- posterior::as_draws_matrix(yrep)
161+
p <- ppc_stat(y, dm)
162+
expect_s3_class(p, "ggplot")
163+
})
164+
116165
# ecdf_intervals ---------------------------------------------------------
117166
test_that("ecdf_intervals returns right dimensions and values", {
118167
lims <- ecdf_intervals(.0001, N = 100, K = 100, L = 1)

0 commit comments

Comments
 (0)