diff --git a/R/NetStats.R b/R/NetStats.R index 54ab5fd..22cc17d 100644 --- a/R/NetStats.R +++ b/R/NetStats.R @@ -1,3 +1,54 @@ +# Parse the target_pop argument to build_netstats. Returns a list with +# $form in {"default", "list", "data.frame"} plus form-specific payload. +# Raises an informative error for the not-yet-implemented character form. +.parse_target_pop <- function(target_pop, race) { + if (is.null(target_pop)) return(list(form = "default")) + if (is.character(target_pop)) { + stop("Built-in reference target populations (target_pop = '", + target_pop, "') are not yet implemented in ARTnet. The planned ", + "set is geography-specific general male population demographics ", + "(NCHS age pyramid + ARTnetData::race.dist by geography); when ", + "those bundles ship, this argument will accept their names. ", + "For now, pass a list of marginal distributions or a data.frame. ", + "See issue #64.", call. = FALSE) + } + if (is.data.frame(target_pop)) { + req <- c("age", "deg.casl", "deg.main", "role.class", "risk.grp") + if (isTRUE(race)) req <- c(req, "race") + missing_cols <- setdiff(req, names(target_pop)) + if (length(missing_cols)) { + stop("target_pop data.frame missing required columns: ", + paste(missing_cols, collapse = ", "), ".", call. = FALSE) + } + if (nrow(target_pop) < 1) { + stop("target_pop data.frame must have at least one row.", + call. = FALSE) + } + return(list(form = "data.frame", df = target_pop)) + } + if (is.list(target_pop)) { + known <- c("age.pyramid", "race.prop", "race.props", + "deg.casl", "deg.main", "deg.tot", + "role.class", "risk.grp") + extra <- setdiff(names(target_pop), known) + if (length(extra)) { + stop("target_pop list has unknown elements: ", + paste(extra, collapse = ", "), + ". Allowed names: ", paste(known, collapse = ", "), ".", + call. = FALSE) + } + # Normalize the alternate name spelling shown in #64's example. + if (!is.null(target_pop$race.props) && is.null(target_pop$race.prop)) { + target_pop$race.prop <- target_pop$race.props + } + return(list(form = "list", overrides = target_pop)) + } + stop("target_pop must be NULL, a list, a data.frame, or a character ", + "string. Got class: ", paste(class(target_pop), collapse = ", "), ".", + call. = FALSE) +} + + # Aggregate synth-population stratum-level durations under joint_lm. Per-ego # predicted log(duration) from the joint_lm fit, marginalized over partner-race # uncertainty using joint_nm_race_model when race = TRUE, then exponentiated @@ -81,6 +132,32 @@ #' @param young.prop The proportion of the population that should be below the age of sexual cessation. #' Default is NULL (meaning no re-weighting of the `age.pyramid` parameter is performed). #' This parameter is only used if the age of sexual cessation is less than the upper age bound. +#' @param target_pop Optional specification of the synthetic target population. Defaults to NULL, +#' which uses the legacy patchwork of reference sources (NCHS age pyramid + +#' `ARTnetData::race.dist` + ARTnet's own degree / role / risk-quintile distributions). +#' Supports three forms: +#' \itemize{ +#' \item A **named list** of marginal distribution overrides. Allowed names: +#' `age.pyramid` (vector of length `nAges`), `race.prop` (length matching +#' `race.level`), `deg.casl` (length 4), `deg.main` (length 3), `deg.tot` (length 4), +#' `role.class` (length 3), `risk.grp` (length matching `nf.risk.grp`). Each +#' element overrides the corresponding default source; absent names fall through +#' to legacy defaults. Equivalent in form to passing the older `age.pyramid` / +#' `race.prop` arguments, but extends the override surface to the per-attribute +#' distributions previously sourced from `netparams`. +#' \item A **data.frame** with one row per node. Required columns: `age`, `deg.casl`, +#' `deg.main`, `role.class`, `risk.grp` (plus `race` when `epistats$race = TRUE`). +#' Optional columns: `sqrt.age`, `age.grp`, `active.sex`, `deg.tot`, `diag.status` +#' (derived from required columns if absent). When supplied, attribute sampling is +#' bypassed entirely and `network.size` is set to `nrow(target_pop)`. This form is +#' for users with a fully-specified joint synthetic population (e.g., post-stratified +#' to NHBS or AMIS demographics). +#' \item A **character string** naming a built-in reference population. Currently +#' raises an informative error. The planned set is geography-specific general male +#' population demographics (NCHS age pyramid + `ARTnetData::race.dist` by +#' geography) — bundles like `"atlanta"` or `"us_msm_male"` packaged from data +#' already in the package, no NHBS or other restricted data required. +#' } #' @param method Character. Either `"existing"` (default) or `"joint"`. `"existing"` reproduces #' the pre-refactor behavior byte-for-byte: target statistics for edges, nodefactor, and #' concurrent are computed layer-by-layer from the univariate marginal fits stored on @@ -166,12 +243,26 @@ build_netstats <- function(epistats, netparams, race.prop = NULL, young.prop = NULL, method = c("existing", "joint"), + target_pop = NULL, browser = FALSE) { method <- match.arg(method) # Ensures that ARTnetData is installed if (system.file(package = "ARTnetData") == "") stop(missing_data_msg) + # target_pop API (#64): NULL preserves the legacy patchwork-default + # behavior; a list overrides specific marginals; a data.frame supplies + # the synthetic population directly. Character-form built-in references + # are not yet implemented. + .tp <- .parse_target_pop(target_pop, race = epistats$race) + if (.tp$form == "list") { + if (!is.null(.tp$overrides$age.pyramid)) age.pyramid <- .tp$overrides$age.pyramid + if (!is.null(.tp$overrides$race.prop)) race.prop <- .tp$overrides$race.prop + } + if (.tp$form == "data.frame") { + network.size <- nrow(.tp$df) + } + if (method == "joint") { missing_joint <- vapply(c("main", "casl", "inst"), function(layer) is.null(netparams[[layer]]$joint_model), @@ -336,102 +427,153 @@ build_netstats <- function(epistats, netparams, nAges <- age.limits[2] - age.limits[1] age.vals <- age.limits[1]:(age.limits[2] - 1) out$demog$ages <- age.vals - if (!is.null(age.pyramid)) { - if (length(age.pyramid) != nAges) { - stop("Length of age.pyramid vector must be equal to length of unique age values: ", nAges) + age.breaks <- out$demog$age.breaks <- epistats$age.breaks + nquants <- length(netparams$inst$nf.risk.grp) + + # List-form distribution overrides for the sampling path. NULL means + # "use the existing default source" (current behavior). + .ov <- if (.tp$form == "list") .tp$overrides else list() + .dist_deg.casl <- if (!is.null(.ov$deg.casl)) .ov$deg.casl else netparams$main$deg.casl.dist + .dist_deg.main <- if (!is.null(.ov$deg.main)) .ov$deg.main else netparams$casl$deg.main.dist + .dist_deg.tot <- if (!is.null(.ov$deg.tot)) .ov$deg.tot else netparams$inst$deg.tot.dist + .dist_role.class <- if (!is.null(.ov$role.class)) .ov$role.class else netparams$all$role.type + .dist_risk.grp <- if (!is.null(.ov$risk.grp)) .ov$risk.grp else rep(1 / nquants, nquants) + + if (.tp$form == "data.frame") { + # ---- target_pop = data.frame: pull attributes directly from user df ---- + df <- .tp$df + attr_age <- df$age + attr_sqrt.age <- if (!is.null(df$sqrt.age)) df$sqrt.age else sqrt(attr_age) + attr_age.grp <- if (!is.null(df$age.grp)) df$age.grp else + cut(attr_age, age.breaks, labels = FALSE, right = FALSE, include.lowest = FALSE) + attr_active.sex <- if (!is.null(df$active.sex)) { + as.integer(df$active.sex) + } else { + as_active <- rep(1L, num) + if (sex.cess.mod == TRUE) { + as_active[attr_age.grp == max(attr_age.grp, na.rm = TRUE)] <- 0L + } + as_active + } + if (!is.null(df$race)) { + attr_race <- df$race + } else { + # race column is required when race = TRUE (validated in + # .parse_target_pop); otherwise sample from race.dist for parity + # with the legacy code path which always populates out$attr$race. + race_numbers <- vapply(flattened_race_level, + function(r) out$demog[[paste0("num.", r)]], numeric(1)) + attr_race <- apportion_lr(num, seq_along(flattened_race_level), + race_numbers / num, shuffled = TRUE) } + attr_deg.casl <- df$deg.casl + attr_deg.main <- df$deg.main + attr_deg.tot <- if (!is.null(df$deg.tot)) df$deg.tot else + pmin(attr_deg.main + attr_deg.casl, 3L) + attr_risk.grp <- df$risk.grp + attr_role.class <- df$role.class } else { - full.age.pyr <- c(0.01202, 0.01228, 0.01250, 0.01280, 0.01292, 0.01289, - 0.01284, 0.01286, 0.01301, 0.01297, 0.01296, 0.01337, - 0.01344, 0.01334, 0.01329, 0.01332, 0.01325, 0.01323, - 0.01360, 0.01385, 0.01365, 0.01368, 0.01373, 0.01389, - 0.01424, 0.01454, 0.01480, 0.01514, 0.01533, 0.01525, - 0.01463, 0.01426, 0.01400, 0.01401, 0.01404, 0.01352, - 0.01366, 0.01360, 0.01339, 0.01367, 0.01274, 0.01246, - 0.01227, 0.01190, 0.01226, 0.01181, 0.01192, 0.01245, - 0.01312, 0.01331, 0.01255, 0.01225, 0.01219, 0.01239, - 0.01308, 0.01321, 0.01313, 0.01303, 0.01311, 0.01320, - 0.01264, 0.01247, 0.01223, 0.01169, 0.01152, 0.01092, - 0.01042, 0.00994, 0.00954, 0.00926, 0.00890, 0.00872, - 0.00899, 0.00650, 0.00630, 0.00600, 0.00601, 0.00509, - 0.00451, 0.00414, 0.00377, 0.00346, 0.00304, 0.00275, - 0.00235, 0.00202, 0.00173, 0.00148, 0.00127, 0.00109, - 0.00093, 0.00080, 0.00068, 0.00059, 0.00050, 0.00043, - 0.00037, 0.00032, 0.00027, 0.00023) - age.pyramid <- full.age.pyr[age.vals] - } + # ---- Existing sampling path (default + list-form marginal overrides) ---- + if (!is.null(age.pyramid)) { + if (length(age.pyramid) != nAges) { + stop("Length of age.pyramid vector must be equal to length of unique age values: ", nAges) + } + } else { + full.age.pyr <- c(0.01202, 0.01228, 0.01250, 0.01280, 0.01292, 0.01289, + 0.01284, 0.01286, 0.01301, 0.01297, 0.01296, 0.01337, + 0.01344, 0.01334, 0.01329, 0.01332, 0.01325, 0.01323, + 0.01360, 0.01385, 0.01365, 0.01368, 0.01373, 0.01389, + 0.01424, 0.01454, 0.01480, 0.01514, 0.01533, 0.01525, + 0.01463, 0.01426, 0.01400, 0.01401, 0.01404, 0.01352, + 0.01366, 0.01360, 0.01339, 0.01367, 0.01274, 0.01246, + 0.01227, 0.01190, 0.01226, 0.01181, 0.01192, 0.01245, + 0.01312, 0.01331, 0.01255, 0.01225, 0.01219, 0.01239, + 0.01308, 0.01321, 0.01313, 0.01303, 0.01311, 0.01320, + 0.01264, 0.01247, 0.01223, 0.01169, 0.01152, 0.01092, + 0.01042, 0.00994, 0.00954, 0.00926, 0.00890, 0.00872, + 0.00899, 0.00650, 0.00630, 0.00600, 0.00601, 0.00509, + 0.00451, 0.00414, 0.00377, 0.00346, 0.00304, 0.00275, + 0.00235, 0.00202, 0.00173, 0.00148, 0.00127, 0.00109, + 0.00093, 0.00080, 0.00068, 0.00059, 0.00050, 0.00043, + 0.00037, 0.00032, 0.00027, 0.00023) + age.pyramid <- full.age.pyr[age.vals] + } - if (age.sexual.cessation < age.limits[2] && !is.null(young.prop)) { - age.break <- age.sexual.cessation - (age.limits[1] - 1) - age.pyramid <- reweight_age_pyr(age.pyramid, young.prop, age.break) - } + if (age.sexual.cessation < age.limits[2] && !is.null(young.prop)) { + age.break <- age.sexual.cessation - (age.limits[1] - 1) + age.pyramid <- reweight_age_pyr(age.pyramid, young.prop, age.break) + } - attr_age <- sample(x = age.vals, size = num, prob = age.pyramid, replace = TRUE) - age_noise <- runif(num) - attr_age <- attr_age + age_noise - out$attr$age <- attr_age + attr_age <- sample(x = age.vals, size = num, prob = age.pyramid, replace = TRUE) + age_noise <- runif(num) + attr_age <- attr_age + age_noise - attr_sqrt.age <- sqrt(attr_age) - out$attr$sqrt.age <- attr_sqrt.age + attr_sqrt.age <- sqrt(attr_age) - age.breaks <- out$demog$age.breaks <- epistats$age.breaks - attr_age.grp <- cut(attr_age, age.breaks, labels = FALSE, right = FALSE, include.lowest = FALSE) - out$attr$age.grp <- attr_age.grp + attr_age.grp <- cut(attr_age, age.breaks, labels = FALSE, right = FALSE, include.lowest = FALSE) - # sexually active attribute - attr_active.sex <- rep(1L, num) - if (sex.cess.mod == TRUE) { - attr_active.sex[attr_age.grp == max(attr_age.grp)] <- 0L - } - out$attr$active.sex <- attr_active.sex + # sexually active attribute + attr_active.sex <- rep(1L, num) + if (sex.cess.mod == TRUE) { + attr_active.sex[attr_age.grp == max(attr_age.grp)] <- 0L + } - # race attribute - race_numbers <- vapply( - flattened_race_level, - function(race) { - race_num_var <- paste0("num.", race) - out$demog[[race_num_var]] - }, - numeric(1) - ) + # race attribute + race_numbers <- vapply( + flattened_race_level, + function(race) { + race_num_var <- paste0("num.", race) + out$demog[[race_num_var]] + }, + numeric(1) + ) - race_proportions <- race_numbers / num - group_ids <- seq_along(flattened_race_level) - attr_race <- apportion_lr(num, group_ids, race_proportions, shuffled = TRUE) - out$attr$race <- attr_race + race_proportions <- race_numbers / num + group_ids <- seq_along(flattened_race_level) + attr_race <- apportion_lr(num, group_ids, race_proportions, shuffled = TRUE) - # deg.casl attribute - attr_deg.casl <- apportion_lr(num, 0:3, netparams$main$deg.casl.dist, shuffled = TRUE) - if (sex.cess.mod == TRUE) { - attr_deg.casl[attr_active.sex == 0] <- 0 - } - out$attr$deg.casl <- attr_deg.casl + # deg.casl attribute + attr_deg.casl <- apportion_lr(num, 0:3, .dist_deg.casl, shuffled = TRUE) + if (sex.cess.mod == TRUE) { + attr_deg.casl[attr_active.sex == 0] <- 0 + } - # deg main attribute - attr_deg.main <- apportion_lr(num, 0:2, netparams$casl$deg.main.dist, shuffled = TRUE) - if (sex.cess.mod == TRUE) { - attr_deg.main[attr_active.sex == 0] <- 0 - } - out$attr$deg.main <- attr_deg.main + # deg main attribute + attr_deg.main <- apportion_lr(num, 0:2, .dist_deg.main, shuffled = TRUE) + if (sex.cess.mod == TRUE) { + attr_deg.main[attr_active.sex == 0] <- 0 + } + + # deg tot 3 attribute + attr_deg.tot <- apportion_lr(num, 0:3, .dist_deg.tot, shuffled = TRUE) + if (sex.cess.mod == TRUE) { + attr_deg.tot[attr_active.sex == 0] <- 0 + } - # deg tot 3 attribute - attr_deg.tot <- apportion_lr(num, 0:3, netparams$inst$deg.tot.dist, shuffled = TRUE) - if (sex.cess.mod == TRUE) { - attr_deg.tot[attr_active.sex == 0] <- 0 + # risk group + attr_risk.grp <- apportion_lr(num, 1:nquants, .dist_risk.grp, shuffled = TRUE) + + # role class + attr_role.class <- apportion_lr(num, 0:2, .dist_role.class, shuffled = TRUE) } - out$attr$deg.tot <- attr_deg.tot - # risk group - nquants <- length(netparams$inst$nf.risk.grp) - attr_risk.grp <- apportion_lr(num, 1:nquants, rep(1 / nquants, nquants), shuffled = TRUE) + # Common attr assignments (both paths) ----------------------------------- + out$attr$age <- attr_age + out$attr$sqrt.age <- attr_sqrt.age + out$attr$age.grp <- attr_age.grp + out$attr$active.sex <- attr_active.sex + out$attr$race <- attr_race + out$attr$deg.casl <- attr_deg.casl + out$attr$deg.main <- attr_deg.main + out$attr$deg.tot <- attr_deg.tot out$attr$risk.grp <- attr_risk.grp - - # role class - attr_role.class <- apportion_lr(num, 0:2, netparams$all$role.type, shuffled = TRUE) out$attr$role.class <- attr_role.class # diag status - if (is.null(epistats$init.hiv.prev)) { + if (.tp$form == "data.frame" && !is.null(.tp$df$diag.status)) { + # User-supplied diag.status takes precedence over epistats-based draw. + out$attr$diag.status <- as.integer(.tp$df$diag.status) + } else if (is.null(epistats$init.hiv.prev)) { if (race == TRUE) { xs <- data.frame(age = attr_age, race.cat.num = attr_race, geogYN = 1) preds <- predict(epistats$hiv.mod, newdata = xs, type = "response") @@ -465,7 +607,9 @@ build_netstats <- function(epistats, netparams, out$attr$diag.status[attr_diag.status] <- 1 } } - out$attr$diag.status <- as.integer(out$attr$diag.status) + if (!is.integer(out$attr$diag.status)) { + out$attr$diag.status <- as.integer(out$attr$diag.status) + } # Joint g-computation predictions (method = "joint" only) ----------------- diff --git a/man/build_netstats.Rd b/man/build_netstats.Rd index 5122b99..061f5b9 100644 --- a/man/build_netstats.Rd +++ b/man/build_netstats.Rd @@ -14,6 +14,7 @@ build_netstats( race.prop = NULL, young.prop = NULL, method = c("existing", "joint"), + target_pop = NULL, browser = FALSE ) } @@ -64,6 +65,33 @@ stratum. \code{nodefactor_risk.grp} and \code{diss.homog} still use the within-A univariate / aggregated values — those are not consumed by the standard EpiModelHIV-Template dissolution offset.} +\item{target_pop}{Optional specification of the synthetic target population. Defaults to NULL, +which uses the legacy patchwork of reference sources (NCHS age pyramid + +\code{ARTnetData::race.dist} + ARTnet's own degree / role / risk-quintile distributions). +Supports three forms: +\itemize{ +\item A \strong{named list} of marginal distribution overrides. Allowed names: +\code{age.pyramid} (vector of length \code{nAges}), \code{race.prop} (length matching +\code{race.level}), \code{deg.casl} (length 4), \code{deg.main} (length 3), \code{deg.tot} (length 4), +\code{role.class} (length 3), \code{risk.grp} (length matching \code{nf.risk.grp}). Each +element overrides the corresponding default source; absent names fall through +to legacy defaults. Equivalent in form to passing the older \code{age.pyramid} / +\code{race.prop} arguments, but extends the override surface to the per-attribute +distributions previously sourced from \code{netparams}. +\item A \strong{data.frame} with one row per node. Required columns: \code{age}, \code{deg.casl}, +\code{deg.main}, \code{role.class}, \code{risk.grp} (plus \code{race} when \code{epistats$race = TRUE}). +Optional columns: \code{sqrt.age}, \code{age.grp}, \code{active.sex}, \code{deg.tot}, \code{diag.status} +(derived from required columns if absent). When supplied, attribute sampling is +bypassed entirely and \code{network.size} is set to \code{nrow(target_pop)}. This form is +for users with a fully-specified joint synthetic population (e.g., post-stratified +to NHBS or AMIS demographics). +\item A \strong{character string} naming a built-in reference population. Currently +raises an informative error. The planned set is geography-specific general male +population demographics (NCHS age pyramid + \code{ARTnetData::race.dist} by +geography) — bundles like \code{"atlanta"} or \code{"us_msm_male"} packaged from data +already in the package, no NHBS or other restricted data required. +}} + \item{browser}{If \code{TRUE}, run \code{build_netparams} in interactive browser mode.} } \description{ diff --git a/tests/testthat/test-target-pop.R b/tests/testthat/test-target-pop.R new file mode 100644 index 0000000..7b135b3 --- /dev/null +++ b/tests/testthat/test-target-pop.R @@ -0,0 +1,222 @@ +# Tests for the target_pop API on build_netstats (#64). +# Three input forms: list of marginal distributions, data.frame of nodes, +# character (not yet implemented). NULL preserves legacy behavior. + +skip_without_artnetdata <- function() { + testthat::skip_if(system.file(package = "ARTnetData") == "", + "ARTnetData not installed") +} + +setup_pipeline <- function() { + set.seed(20260420L) + ep <- build_epistats(geog.lvl = "city", geog.cat = "Atlanta", + init.hiv.prev = c(0.33, 0.137, 0.084), + race = TRUE, time.unit = 7) + set.seed(20260420L) + np <- build_netparams(ep, smooth.main.dur = TRUE) + list(epistats = ep, netparams = np) +} + +test_that("target_pop = NULL is byte-identical to no target_pop arg", { + skip_without_artnetdata() + s <- setup_pipeline() + set.seed(20260420L) + ns_default <- build_netstats(s$epistats, s$netparams, + expect.mort = 0.000478213, network.size = 2000) + set.seed(20260420L) + ns_explicit <- build_netstats(s$epistats, s$netparams, + expect.mort = 0.000478213, network.size = 2000, + target_pop = NULL) + expect_equal(ns_default$attr, ns_explicit$attr) + expect_equal(ns_default$main$edges, ns_explicit$main$edges) +}) + + +# ---- list form ------------------------------------------------------------- + +test_that("list form: race.prop override produces matching race composition", { + skip_without_artnetdata() + s <- setup_pipeline() + set.seed(20260420L) + ns <- build_netstats(s$epistats, s$netparams, + expect.mort = 0.000478213, network.size = 5000, + target_pop = list(race.prop = c(0.4, 0.2, 0.4))) + obs <- prop.table(table(ns$attr$race)) + expect_equal(as.numeric(obs), c(0.4, 0.2, 0.4), tolerance = 0.01) +}) + +test_that("list form: deg.casl override produces matching distribution", { + skip_without_artnetdata() + s <- setup_pipeline() + set.seed(20260420L) + ns <- build_netstats(s$epistats, s$netparams, + expect.mort = 0.000478213, network.size = 5000, + target_pop = list(deg.casl = c(0.5, 0.3, 0.15, 0.05))) + obs <- prop.table(table(ns$attr$deg.casl)) + expect_equal(as.numeric(obs), c(0.5, 0.3, 0.15, 0.05), tolerance = 0.005) +}) + +test_that("list form: race.props alias is normalized to race.prop", { + skip_without_artnetdata() + s <- setup_pipeline() + set.seed(20260420L) + ns <- build_netstats(s$epistats, s$netparams, + expect.mort = 0.000478213, network.size = 2000, + target_pop = list(race.props = c(0.5, 0.25, 0.25))) + obs <- prop.table(table(ns$attr$race)) + expect_equal(as.numeric(obs), c(0.5, 0.25, 0.25), tolerance = 0.02) +}) + +test_that("list form: unknown elements raise an informative error", { + skip_without_artnetdata() + s <- setup_pipeline() + expect_error( + build_netstats(s$epistats, s$netparams, + expect.mort = 0.000478213, network.size = 1000, + target_pop = list(foo = 1, age.pyramid = NULL)), + regexp = "unknown elements: foo" + ) +}) + + +# ---- data.frame form ------------------------------------------------------- + +test_that("data.frame form bypasses sampling and respects user attrs", { + skip_without_artnetdata() + s <- setup_pipeline() + set.seed(20260420L) + df <- data.frame( + age = sample(15:64, 1500, replace = TRUE), + race = sample(1:3, 1500, replace = TRUE, prob = c(0.4, 0.2, 0.4)), + deg.casl = sample(0:3, 1500, replace = TRUE), + deg.main = sample(0:2, 1500, replace = TRUE), + role.class = sample(0:2, 1500, replace = TRUE), + risk.grp = sample(1:5, 1500, replace = TRUE), + diag.status = rbinom(1500, 1, 0.15) + ) + ns <- build_netstats(s$epistats, s$netparams, + expect.mort = 0.000478213, + network.size = 99999, # must be ignored + target_pop = df) + expect_equal(ns$demog$num, 1500) + expect_length(ns$attr$age, 1500) + expect_equal(ns$attr$age, df$age) + expect_equal(ns$attr$race, df$race) + expect_equal(ns$attr$deg.casl, df$deg.casl) + expect_equal(ns$attr$diag.status, as.integer(df$diag.status)) + # Derived attrs filled in + expect_length(ns$attr$sqrt.age, 1500) + expect_length(ns$attr$age.grp, 1500) + expect_length(ns$attr$deg.tot, 1500) + expect_length(ns$attr$active.sex, 1500) +}) + +test_that("data.frame form: derived deg.tot caps at 3", { + skip_without_artnetdata() + s <- setup_pipeline() + set.seed(20260420L) + n <- 500 + df <- data.frame( + age = sample(15:64, n, replace = TRUE), + race = sample(1:3, n, replace = TRUE), + role.class = sample(0:2, n, replace = TRUE), + risk.grp = sample(1:5, n, replace = TRUE), + deg.casl = c(rep(3L, 5), sample(0:3, n - 5, replace = TRUE)), + deg.main = c(rep(2L, 5), sample(0:2, n - 5, replace = TRUE)) + ) + ns <- build_netstats(s$epistats, s$netparams, + expect.mort = 0.000478213, network.size = 100, + target_pop = df) + # First 5 rows have deg.casl = 3, deg.main = 2: raw sum 5, capped to 3. + expect_equal(ns$attr$deg.tot[1:5], rep(3L, 5)) + # All values must satisfy the cap. + expect_true(all(ns$attr$deg.tot <= 3L)) + expect_true(all(ns$attr$deg.tot >= 0L)) +}) + +test_that("data.frame form: missing required columns raise informative error", { + skip_without_artnetdata() + s <- setup_pipeline() + expect_error( + build_netstats(s$epistats, s$netparams, + expect.mort = 0.000478213, network.size = 100, + target_pop = data.frame(age = 1:5)), + regexp = "data.frame missing required columns" + ) +}) + +test_that("data.frame form: diag.status falls back to epistats when absent", { + skip_without_artnetdata() + s <- setup_pipeline() + set.seed(20260420L) + df <- data.frame( + age = sample(15:64, 500, replace = TRUE), + race = sample(1:3, 500, replace = TRUE), + deg.casl = sample(0:3, 500, replace = TRUE), + deg.main = sample(0:2, 500, replace = TRUE), + role.class = sample(0:2, 500, replace = TRUE), + risk.grp = sample(1:5, 500, replace = TRUE) + ) + ns <- build_netstats(s$epistats, s$netparams, + expect.mort = 0.000478213, network.size = 100, + target_pop = df) + expect_length(ns$attr$diag.status, 500) + expect_true(all(ns$attr$diag.status %in% c(0L, 1L))) +}) + +test_that("data.frame form: composes with method = 'joint'", { + skip_without_artnetdata() + set.seed(20260420L) + ep <- build_epistats(geog.lvl = "city", geog.cat = "Atlanta", + init.hiv.prev = c(0.33, 0.137, 0.084), + race = TRUE, time.unit = 7) + set.seed(20260420L) + np <- build_netparams(ep, smooth.main.dur = TRUE, + method = "joint", duration.method = "joint_lm") + df <- data.frame( + age = sample(15:64, 800, replace = TRUE), + race = sample(1:3, 800, replace = TRUE), + deg.casl = sample(0:3, 800, replace = TRUE), + deg.main = sample(0:2, 800, replace = TRUE), + role.class = sample(0:2, 800, replace = TRUE), + risk.grp = sample(1:5, 800, replace = TRUE), + diag.status = rbinom(800, 1, 0.15) + ) + ns <- build_netstats(ep, np, + expect.mort = 0.000478213, network.size = 100, + target_pop = df, method = "joint") + # Internal consistency under joint must still hold + expect_equal(sum(ns$main$nodefactor_race), 2 * ns$main$edges, + tolerance = 1e-9) +}) + + +# ---- character form -------------------------------------------------------- + +test_that("character form raises informative not-yet-implemented error", { + skip_without_artnetdata() + s <- setup_pipeline() + # The planned named set is geography-specific general male population + # demographics (NCHS age pyramid + ARTnetData::race.dist), e.g. "atlanta". + # Currently any character string raises a clear not-yet-implemented error. + expect_error( + build_netstats(s$epistats, s$netparams, + expect.mort = 0.000478213, network.size = 1000, + target_pop = "atlanta"), + regexp = "not yet implemented" + ) +}) + + +# ---- bad input ------------------------------------------------------------- + +test_that("non-list non-data.frame non-character input raises error", { + skip_without_artnetdata() + s <- setup_pipeline() + expect_error( + build_netstats(s$epistats, s$netparams, + expect.mort = 0.000478213, network.size = 1000, + target_pop = 42), + regexp = "must be NULL, a list, a data.frame" + ) +})