From aa2e6af0ce1f5196efa9fb8e1999e0f6f0307975 Mon Sep 17 00:00:00 2001 From: ishaan-arora-1 Date: Tue, 10 Mar 2026 17:16:31 +0530 Subject: [PATCH] Add unified wrapper functions for bayesplot plot families Introduce high-level wrapper functions that consolidate related plotting functions into single entry points with a `type` argument: - ppc_error(), ppc_dist(), ppc_discrete(), ppc_loo() - ppd_dist() - mcmc_dist(), mcmc_trace_w(), mcmc_diag(), mcmc_nuts(), mcmc_recover() Each wrapper dispatches to the underlying function (e.g., ppc_dist(type = "hist") calls ppc_hist()). Wrappers that have grouped variants accept a `grouped` flag. Fixes #433 --- NAMESPACE | 10 ++ NEWS.md | 4 + R/wrappers.R | 273 +++++++++++++++++++++++++++++++++ man/bayesplot-wrappers.Rd | 195 +++++++++++++++++++++++ tests/testthat/test-wrappers.R | 123 +++++++++++++++ 5 files changed, 605 insertions(+) create mode 100644 R/wrappers.R create mode 100644 man/bayesplot-wrappers.Rd create mode 100644 tests/testthat/test-wrappers.R diff --git a/NAMESPACE b/NAMESPACE index b8cf818b..2951a3ed 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -74,6 +74,8 @@ export(mcmc_dens) export(mcmc_dens_chains) export(mcmc_dens_chains_data) export(mcmc_dens_overlay) +export(mcmc_diag) +export(mcmc_dist) export(mcmc_dots) export(mcmc_dots_by_chain) export(mcmc_hex) @@ -84,6 +86,7 @@ export(mcmc_intervals_data) export(mcmc_neff) export(mcmc_neff_data) export(mcmc_neff_hist) +export(mcmc_nuts) export(mcmc_nuts_acceptance) export(mcmc_nuts_divergence) export(mcmc_nuts_energy) @@ -95,6 +98,7 @@ export(mcmc_parcoord_data) export(mcmc_rank_ecdf) export(mcmc_rank_hist) export(mcmc_rank_overlay) +export(mcmc_recover) export(mcmc_recover_hist) export(mcmc_recover_intervals) export(mcmc_recover_scatter) @@ -105,6 +109,7 @@ export(mcmc_scatter) export(mcmc_trace) export(mcmc_trace_data) export(mcmc_trace_highlight) +export(mcmc_trace_w) export(mcmc_violin) export(neff_ratio) export(nuts_params) @@ -125,9 +130,12 @@ export(ppc_data) export(ppc_dens) export(ppc_dens_overlay) export(ppc_dens_overlay_grouped) +export(ppc_discrete) +export(ppc_dist) export(ppc_dots) export(ppc_ecdf_overlay) export(ppc_ecdf_overlay_grouped) +export(ppc_error) export(ppc_error_binned) export(ppc_error_data) export(ppc_error_hist) @@ -144,6 +152,7 @@ export(ppc_intervals_data) export(ppc_intervals_grouped) export(ppc_km_overlay) export(ppc_km_overlay_grouped) +export(ppc_loo) export(ppc_loo_intervals) export(ppc_loo_pit) export(ppc_loo_pit_data) @@ -173,6 +182,7 @@ export(ppd_boxplot) export(ppd_data) export(ppd_dens) export(ppd_dens_overlay) +export(ppd_dist) export(ppd_dots) export(ppd_ecdf_overlay) export(ppd_freqpoly) diff --git a/NEWS.md b/NEWS.md index b3e8922c..5df27028 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,9 @@ # bayesplot (development version) +* New unified wrapper functions for plot families (`ppc_error()`, `ppc_dist()`, + `ppc_discrete()`, `ppc_loo()`, `ppd_dist()`, `mcmc_dist()`, `mcmc_trace_w()`, + `mcmc_diag()`, `mcmc_nuts()`, `mcmc_recover()`) that dispatch to the + underlying plotting functions via a `type` argument (#433) * Documentation added for all exported `*_data()` functions (#209) * Improved documentation for `binwidth`, `bins`, and `breaks` arguments to clarify they are passed to `ggplot2::geom_area()` and `ggdist::stat_dots()` in addition to `ggplot2::geom_histogram()` * Improved documentation for `freq` argument to clarify it applies to frequency polygons in addition to histograms diff --git a/R/wrappers.R b/R/wrappers.R new file mode 100644 index 00000000..1393b74e --- /dev/null +++ b/R/wrappers.R @@ -0,0 +1,273 @@ +#' Unified wrapper functions for bayesplot families +#' +#' @description +#' High-level wrapper functions that consolidate related plotting functions +#' within each bayesplot family into a single entry point. Each wrapper uses a +#' `type` argument to select the specific plot variant, and (where applicable) a +#' `grouped` flag to switch to a faceted-by-group version. +#' +#' All additional arguments are forwarded to the underlying plotting function +#' via `...`. +#' +#' @details +#' These wrappers are thin dispatchers. The underlying functions and their full +#' documentation are still available directly (e.g., [ppc_error_hist()], +#' [mcmc_rhat()]). Use [available_ppc()], [available_ppd()], or +#' [available_mcmc()] to discover all plotting functions. +#' +#' @name bayesplot-wrappers +#' @seealso [available_ppc()], [available_ppd()], [available_mcmc()], +#' [PPC-overview] +NULL + + +# Internal dispatcher ----------------------------------------------------- + +.dispatch <- function(prefix, type, is_grouped, ...) { + suffix <- type + if (is_grouped) { + suffix <- paste0(suffix, "_grouped") + } + fn_name <- paste0(prefix, "_", suffix) + fn <- tryCatch( + match.fun(fn_name), + error = function(e) { + if (is_grouped) { + abort(paste0( + "No grouped variant available for type '", type, "'. ", + "Try `grouped = FALSE` or a different `type`." + )) + } + abort(paste0( + "'", type, "' is not a valid type for `", prefix, "()`. ", + "See the documentation for supported types." + )) + } + ) + fn(...) +} + + +# PPC wrappers ------------------------------------------------------------ + +#' @rdname bayesplot-wrappers +#' @export +#' +#' @param type A string selecting the specific plot variant within the family. +#' Partial matching is supported. See each wrapper's section below for valid +#' values. +#' @param grouped If `TRUE`, use the grouped (faceted) variant of the plot when +#' one exists. A `group` argument must also be provided via `...`. +#' @param ... Arguments passed to the underlying plotting function +#' (e.g., `y`, `yrep`, `prob`, `size`, etc.). +#' +#' @section PPC error plots (`ppc_error()`): +#' Dispatches to the [PPC-errors] family. +#' +#' Valid types: `"hist"`, `"scatter"`, `"scatter_avg"`, `"binned"`. +#' Grouped variants exist for `"hist"` and `"scatter_avg"`. +#' +#' @examples +#' y <- example_y_data() +#' yrep <- example_yrep_draws() +#' +#' \donttest{ +#' ppc_error(y = y, yrep = yrep[1:3, ], type = "hist") +#' } +#' +ppc_error <- function(type = c("hist", "scatter", "scatter_avg", "binned"), + grouped = FALSE, + ...) { + type <- match.arg(type) + .dispatch("ppc_error", type, grouped, ...) +} + + +#' @rdname bayesplot-wrappers +#' @export +#' +#' @section PPC distribution plots (`ppc_dist()`): +#' Dispatches to the [PPC-distributions] family. +#' +#' Valid types: `"hist"`, `"dens"`, `"dens_overlay"`, `"ecdf_overlay"`, +#' `"freqpoly"`, `"boxplot"`, `"violin"`, `"dots"`, `"pit_ecdf"`. +#' Grouped variants exist for `"dens_overlay"`, `"ecdf_overlay"`, +#' `"freqpoly"`, `"violin"`, and `"pit_ecdf"`. +#' +#' Note: `"violin"` only exists as a grouped variant, so `grouped = TRUE` is +#' required. +#' +#' @examples +#' \donttest{ +#' ppc_dist(y = y, yrep = yrep[1:8, ], type = "hist") +#' ppc_dist(y = y, yrep = yrep, type = "dens_overlay") +#' } +#' +ppc_dist <- function(type = c("hist", "dens", "dens_overlay", "ecdf_overlay", + "freqpoly", "boxplot", "violin", "dots", + "pit_ecdf"), + grouped = FALSE, + ...) { + type <- match.arg(type) + if (type == "violin" && !grouped) { + abort(paste0( + "`ppc_violin_grouped()` only exists as a grouped variant. ", + "Use `grouped = TRUE`." + )) + } + .dispatch("ppc", type, grouped, ...) +} + + +#' @rdname bayesplot-wrappers +#' @export +#' +#' @section PPC discrete plots (`ppc_discrete()`): +#' Dispatches to the [PPC-discrete] family. +#' +#' Valid types: `"bars"`, `"rootogram"`. +#' A grouped variant exists for `"bars"`. +#' +ppc_discrete <- function(type = c("bars", "rootogram"), + grouped = FALSE, + ...) { + type <- match.arg(type) + .dispatch("ppc", type, grouped, ...) +} + + +#' @rdname bayesplot-wrappers +#' @export +#' +#' @section PPC LOO plots (`ppc_loo()`): +#' Dispatches to the [PPC-loo] family. +#' +#' Valid types: `"pit_overlay"`, `"pit_qq"`, `"pit_ecdf"`, +#' `"intervals"`, `"ribbon"`. +#' No grouped variants. +#' +ppc_loo <- function(type = c("pit_overlay", "pit_qq", "pit_ecdf", + "intervals", "ribbon"), + ...) { + type <- match.arg(type) + fn_name <- paste0("ppc_loo_", type) + fn <- match.fun(fn_name) + fn(...) +} + + +# PPD wrappers ------------------------------------------------------------- + +#' @rdname bayesplot-wrappers +#' @export +#' +#' @section PPD distribution plots (`ppd_dist()`): +#' Dispatches to the [PPD-distributions] family. +#' +#' Valid types: `"hist"`, `"dens"`, `"dens_overlay"`, `"ecdf_overlay"`, +#' `"freqpoly"`, `"boxplot"`, `"dots"`. +#' A grouped variant exists for `"freqpoly"`. +#' +ppd_dist <- function(type = c("hist", "dens", "dens_overlay", "ecdf_overlay", + "freqpoly", "boxplot", "dots"), + grouped = FALSE, + ...) { + type <- match.arg(type) + .dispatch("ppd", type, grouped, ...) +} + + +# MCMC wrappers ------------------------------------------------------------ + +#' @rdname bayesplot-wrappers +#' @export +#' +#' @section MCMC distribution plots (`mcmc_dist()`): +#' Dispatches to the [MCMC-distributions] family. +#' +#' Valid types: `"hist"`, `"dens"`, `"hist_by_chain"`, `"dens_overlay"`, +#' `"dens_chains"`, `"violin"`, `"dots"`, `"dots_by_chain"`. +#' +mcmc_dist <- function(type = c("hist", "dens", "hist_by_chain", + "dens_overlay", "dens_chains", "violin", + "dots", "dots_by_chain"), + ...) { + type <- match.arg(type) + fn_name <- paste0("mcmc_", type) + fn <- match.fun(fn_name) + fn(...) +} + + +#' @rdname bayesplot-wrappers +#' @export +#' +#' @section MCMC trace plots (`mcmc_trace_w()`): +#' Dispatches to the [MCMC-traces] family. +#' +#' Valid types: `"trace"`, `"trace_highlight"`, `"rank_overlay"`, +#' `"rank_hist"`, `"rank_ecdf"`. +#' +mcmc_trace_w <- function(type = c("trace", "trace_highlight", "rank_overlay", + "rank_hist", "rank_ecdf"), + ...) { + type <- match.arg(type) + fn_name <- paste0("mcmc_", type) + fn <- match.fun(fn_name) + fn(...) +} + + +#' @rdname bayesplot-wrappers +#' @export +#' +#' @section MCMC diagnostic plots (`mcmc_diag()`): +#' Dispatches to the [MCMC-diagnostics] family. +#' +#' Valid types: `"rhat"`, `"rhat_hist"`, `"neff"`, `"neff_hist"`, +#' `"acf"`, `"acf_bar"`. +#' +mcmc_diag <- function(type = c("rhat", "rhat_hist", "neff", "neff_hist", + "acf", "acf_bar"), + ...) { + type <- match.arg(type) + fn_name <- paste0("mcmc_", type) + fn <- match.fun(fn_name) + fn(...) +} + + +#' @rdname bayesplot-wrappers +#' @export +#' +#' @section MCMC NUTS diagnostic plots (`mcmc_nuts()`): +#' Dispatches to the [MCMC-nuts] family. +#' +#' Valid types: `"acceptance"`, `"divergence"`, `"stepsize"`, +#' `"treedepth"`, `"energy"`. +#' +mcmc_nuts <- function(type = c("acceptance", "divergence", "stepsize", + "treedepth", "energy"), + ...) { + type <- match.arg(type) + fn_name <- paste0("mcmc_nuts_", type) + fn <- match.fun(fn_name) + fn(...) +} + + +#' @rdname bayesplot-wrappers +#' @export +#' +#' @section MCMC recovery plots (`mcmc_recover()`): +#' Dispatches to the [MCMC-recover] family. +#' +#' Valid types: `"intervals"`, `"scatter"`, `"hist"`. +#' +mcmc_recover <- function(type = c("intervals", "scatter", "hist"), + ...) { + type <- match.arg(type) + fn_name <- paste0("mcmc_recover_", type) + fn <- match.fun(fn_name) + fn(...) +} diff --git a/man/bayesplot-wrappers.Rd b/man/bayesplot-wrappers.Rd new file mode 100644 index 00000000..c3188b69 --- /dev/null +++ b/man/bayesplot-wrappers.Rd @@ -0,0 +1,195 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/wrappers.R +\name{bayesplot-wrappers} +\alias{bayesplot-wrappers} +\alias{ppc_error} +\alias{ppc_dist} +\alias{ppc_discrete} +\alias{ppc_loo} +\alias{ppd_dist} +\alias{mcmc_dist} +\alias{mcmc_trace_w} +\alias{mcmc_diag} +\alias{mcmc_nuts} +\alias{mcmc_recover} +\title{Unified wrapper functions for bayesplot families} +\usage{ +ppc_error( + type = c("hist", "scatter", "scatter_avg", "binned"), + grouped = FALSE, + ... +) + +ppc_dist( + type = c("hist", "dens", "dens_overlay", "ecdf_overlay", "freqpoly", "boxplot", + "violin", "dots", "pit_ecdf"), + grouped = FALSE, + ... +) + +ppc_discrete(type = c("bars", "rootogram"), grouped = FALSE, ...) + +ppc_loo( + type = c("pit_overlay", "pit_qq", "pit_ecdf", "intervals", "ribbon"), + ... +) + +ppd_dist( + type = c("hist", "dens", "dens_overlay", "ecdf_overlay", "freqpoly", "boxplot", "dots"), + grouped = FALSE, + ... +) + +mcmc_dist( + type = c("hist", "dens", "hist_by_chain", "dens_overlay", "dens_chains", "violin", + "dots", "dots_by_chain"), + ... +) + +mcmc_trace_w( + type = c("trace", "trace_highlight", "rank_overlay", "rank_hist", "rank_ecdf"), + ... +) + +mcmc_diag( + type = c("rhat", "rhat_hist", "neff", "neff_hist", "acf", "acf_bar"), + ... +) + +mcmc_nuts( + type = c("acceptance", "divergence", "stepsize", "treedepth", "energy"), + ... +) + +mcmc_recover(type = c("intervals", "scatter", "hist"), ...) +} +\arguments{ +\item{type}{A string selecting the specific plot variant within the family. +Partial matching is supported. See each wrapper's section below for valid +values.} + +\item{grouped}{If \code{TRUE}, use the grouped (faceted) variant of the plot when +one exists. A \code{group} argument must also be provided via \code{...}.} + +\item{...}{Arguments passed to the underlying plotting function +(e.g., \code{y}, \code{yrep}, \code{prob}, \code{size}, etc.).} +} +\description{ +High-level wrapper functions that consolidate related plotting functions +within each bayesplot family into a single entry point. Each wrapper uses a +\code{type} argument to select the specific plot variant, and (where applicable) a +\code{grouped} flag to switch to a faceted-by-group version. + +All additional arguments are forwarded to the underlying plotting function +via \code{...}. +} +\details{ +These wrappers are thin dispatchers. The underlying functions and their full +documentation are still available directly (e.g., \code{\link[=ppc_error_hist]{ppc_error_hist()}}, +\code{\link[=mcmc_rhat]{mcmc_rhat()}}). Use \code{\link[=available_ppc]{available_ppc()}}, \code{\link[=available_ppd]{available_ppd()}}, or +\code{\link[=available_mcmc]{available_mcmc()}} to discover all plotting functions. +} +\section{PPC error plots (\code{ppc_error()})}{ + +Dispatches to the \link{PPC-errors} family. + +Valid types: \code{"hist"}, \code{"scatter"}, \code{"scatter_avg"}, \code{"binned"}. +Grouped variants exist for \code{"hist"} and \code{"scatter_avg"}. +} + +\section{PPC distribution plots (\code{ppc_dist()})}{ + +Dispatches to the \link{PPC-distributions} family. + +Valid types: \code{"hist"}, \code{"dens"}, \code{"dens_overlay"}, \code{"ecdf_overlay"}, +\code{"freqpoly"}, \code{"boxplot"}, \code{"violin"}, \code{"dots"}, \code{"pit_ecdf"}. +Grouped variants exist for \code{"dens_overlay"}, \code{"ecdf_overlay"}, +\code{"freqpoly"}, \code{"violin"}, and \code{"pit_ecdf"}. + +Note: \code{"violin"} only exists as a grouped variant, so \code{grouped = TRUE} is +required. +} + +\section{PPC discrete plots (\code{ppc_discrete()})}{ + +Dispatches to the \link{PPC-discrete} family. + +Valid types: \code{"bars"}, \code{"rootogram"}. +A grouped variant exists for \code{"bars"}. +} + +\section{PPC LOO plots (\code{ppc_loo()})}{ + +Dispatches to the \link{PPC-loo} family. + +Valid types: \code{"pit_overlay"}, \code{"pit_qq"}, \code{"pit_ecdf"}, +\code{"intervals"}, \code{"ribbon"}. +No grouped variants. +} + +\section{PPD distribution plots (\code{ppd_dist()})}{ + +Dispatches to the \link{PPD-distributions} family. + +Valid types: \code{"hist"}, \code{"dens"}, \code{"dens_overlay"}, \code{"ecdf_overlay"}, +\code{"freqpoly"}, \code{"boxplot"}, \code{"dots"}. +A grouped variant exists for \code{"freqpoly"}. +} + +\section{MCMC distribution plots (\code{mcmc_dist()})}{ + +Dispatches to the \link{MCMC-distributions} family. + +Valid types: \code{"hist"}, \code{"dens"}, \code{"hist_by_chain"}, \code{"dens_overlay"}, +\code{"dens_chains"}, \code{"violin"}, \code{"dots"}, \code{"dots_by_chain"}. +} + +\section{MCMC trace plots (\code{mcmc_trace_w()})}{ + +Dispatches to the \link{MCMC-traces} family. + +Valid types: \code{"trace"}, \code{"trace_highlight"}, \code{"rank_overlay"}, +\code{"rank_hist"}, \code{"rank_ecdf"}. +} + +\section{MCMC diagnostic plots (\code{mcmc_diag()})}{ + +Dispatches to the \link{MCMC-diagnostics} family. + +Valid types: \code{"rhat"}, \code{"rhat_hist"}, \code{"neff"}, \code{"neff_hist"}, +\code{"acf"}, \code{"acf_bar"}. +} + +\section{MCMC NUTS diagnostic plots (\code{mcmc_nuts()})}{ + +Dispatches to the \link{MCMC-nuts} family. + +Valid types: \code{"acceptance"}, \code{"divergence"}, \code{"stepsize"}, +\code{"treedepth"}, \code{"energy"}. +} + +\section{MCMC recovery plots (\code{mcmc_recover()})}{ + +Dispatches to the \link{MCMC-recover} family. + +Valid types: \code{"intervals"}, \code{"scatter"}, \code{"hist"}. +} + +\examples{ +y <- example_y_data() +yrep <- example_yrep_draws() + +\donttest{ +ppc_error(y = y, yrep = yrep[1:3, ], type = "hist") +} + +\donttest{ +ppc_dist(y = y, yrep = yrep[1:8, ], type = "hist") +ppc_dist(y = y, yrep = yrep, type = "dens_overlay") +} + +} +\seealso{ +\code{\link[=available_ppc]{available_ppc()}}, \code{\link[=available_ppd]{available_ppd()}}, \code{\link[=available_mcmc]{available_mcmc()}}, +\link{PPC-overview} +} diff --git a/tests/testthat/test-wrappers.R b/tests/testthat/test-wrappers.R new file mode 100644 index 00000000..4e395588 --- /dev/null +++ b/tests/testthat/test-wrappers.R @@ -0,0 +1,123 @@ +source(test_path("data-for-ppc-tests.R")) +source(test_path("data-for-mcmc-tests.R")) + +# PPC wrappers ------------------------------------------------------------- + +test_that("ppc_error dispatches correctly", { + skip_if_not_installed("rstantools") + expect_gg(ppc_error(y = y, yrep = yrep[1:3, ], type = "hist")) + expect_gg(ppc_error(y = y, yrep = yrep[1, , drop = FALSE], type = "scatter")) + expect_gg(ppc_error(y = y, yrep = yrep, type = "scatter_avg")) + expect_gg(ppc_error(y = y, yrep = yrep, type = "binned")) + expect_gg(ppc_error(y = y, yrep = yrep[1:3, ], type = "hist", + grouped = TRUE, group = group)) + expect_error(ppc_error(y = y, yrep = yrep, type = "binned", grouped = TRUE), + "No grouped variant") +}) + +test_that("ppc_dist dispatches correctly", { + expect_gg(ppc_dist(y = y, yrep = yrep[1:3, ], type = "hist")) + expect_gg(ppc_dist(y = y, yrep = yrep[1:3, ], type = "dens")) + expect_gg(ppc_dist(y = y, yrep = yrep, type = "dens_overlay")) + expect_gg(ppc_dist(y = y, yrep = yrep, type = "ecdf_overlay")) + expect_gg(ppc_dist(y = y, yrep = yrep[1:3, ], type = "freqpoly")) + expect_gg(ppc_dist(y = y, yrep = yrep[1:3, ], type = "boxplot")) + expect_gg(ppc_dist(y = y, yrep = yrep, type = "dens_overlay", + grouped = TRUE, group = group)) + expect_error(ppc_dist(y = y, yrep = yrep, type = "violin"), + "grouped variant") + expect_gg(ppc_dist(y = y, yrep = yrep, type = "violin", + grouped = TRUE, group = group)) +}) + +test_that("ppc_discrete dispatches correctly", { + expect_gg(ppc_discrete(y = y2, yrep = yrep2, type = "bars")) + expect_gg(ppc_discrete(y = y2, yrep = yrep2, type = "bars", + grouped = TRUE, group = group2)) + expect_error(ppc_discrete(y = y2, yrep = yrep2, type = "rootogram", + grouped = TRUE), + "No grouped variant") +}) + +test_that("ppc_loo dispatches correctly", { + skip_if_not_installed("loo") + skip_if_not_installed("rstantools") + skip_if(packageVersion("rstantools") <= "2.4.0") + + expect_gg(suppressMessages( + ppc_loo(y = vdiff_loo_y, yrep = vdiff_loo_yrep, lw = vdiff_loo_lw, + type = "pit_overlay") + )) + expect_gg(suppressWarnings( + ppc_loo(y = vdiff_loo_y, yrep = vdiff_loo_yrep, lw = vdiff_loo_lw, + type = "pit_qq") + )) +}) + +test_that("ppc_error partial matching works", { + skip_if_not_installed("rstantools") + expect_gg(ppc_error(y = y, yrep = yrep[1:3, ], type = "hi")) +}) + +# PPD wrappers ------------------------------------------------------------- + +test_that("ppd_dist dispatches correctly", { + expect_gg(ppd_dist(ypred = yrep[1:3, ], type = "hist")) + expect_gg(ppd_dist(ypred = yrep, type = "dens_overlay")) + expect_gg(ppd_dist(ypred = yrep[1:3, ], type = "boxplot")) + expect_gg(ppd_dist(ypred = yrep[1:3, ], type = "freqpoly", + grouped = TRUE, group = group)) +}) + +# MCMC wrappers ------------------------------------------------------------ + +test_that("mcmc_dist dispatches correctly", { + expect_gg(mcmc_dist(x = dframe, type = "hist")) + expect_gg(mcmc_dist(x = dframe, type = "dens")) + expect_gg(mcmc_dist(x = arr, type = "dens_overlay")) +}) + +test_that("mcmc_trace_w dispatches correctly", { + expect_gg(mcmc_trace_w(x = arr, type = "trace", pars = "sigma")) + expect_gg(mcmc_trace_w(x = arr, type = "rank_overlay", pars = "sigma")) +}) + +test_that("mcmc_diag dispatches correctly", { + rhat_val <- runif(10, 1, 1.5) + expect_gg(mcmc_diag(rhat = rhat_val, type = "rhat")) + expect_gg(mcmc_diag(rhat = rhat_val, type = "rhat_hist")) + + neff_val <- runif(10, 0, 1) + expect_gg(mcmc_diag(ratio = neff_val, type = "neff")) +}) + +test_that("mcmc_nuts dispatches correctly", { + skip_if_not_installed("gridExtra") + np <- data.frame( + Parameter = rep(c("accept_stat__", "stepsize__", "treedepth__", + "n_leapfrog__", "divergent__", "energy__"), each = 100), + Value = rnorm(600), + Chain = rep(1, 600), + Iteration = rep(1:100, 6) + ) + lp <- data.frame( + Parameter = rep("lp__", 100), + Value = rnorm(100), + Chain = rep(1, 100), + Iteration = 1:100 + ) + expect_gg(mcmc_nuts(x = np, lp = lp, type = "acceptance")) + expect_gg(mcmc_nuts(x = np, lp = lp, type = "divergence")) + expect_gg(mcmc_nuts(x = np, type = "energy")) +}) + +test_that("mcmc_recover dispatches correctly", { + true_vals <- colMeans(mat) + expect_gg(mcmc_recover(x = mat, true = true_vals, type = "intervals")) + expect_gg(mcmc_recover(x = mat, true = true_vals, type = "scatter")) + expect_gg(mcmc_recover(x = mat, true = true_vals, type = "hist")) +}) + +test_that("invalid type errors helpfully", { + expect_error(ppc_error(y = y, yrep = yrep, type = "nonexistent")) +})