Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ SystemRequirements: pandoc (>= 1.12.3), pandoc-citeproc
Depends:
R (>= 4.1.0)
Imports:
dplyr (>= 0.8.0),
dplyr (>= 1.0.0),
ggplot2 (>= 3.4.0),
ggridges (>= 0.5.5),
glue,
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# bayesplot (development version)

* Replaced deprecated `dplyr` and `tidyselect` functions (`top_n`, `one_of`, `group_indices`) with their modern equivalents to ensure future compatibility. (#431)
* Documentation added for all exported `*_data()` functions (#209)
* Improved documentation for `binwidth`, `bins`, and `breaks` arguments to clarify they are passed to `ggplot2::geom_area()` and `ggdist::stat_dots()` in addition to `ggplot2::geom_histogram()`
* Improved documentation for `freq` argument to clarify it applies to frequency polygons in addition to histograms
Expand Down
25 changes: 11 additions & 14 deletions R/mcmc-intervals.R
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ mcmc_intervals_data <- function(x,

rhat_tbl <- rhat %>%
mcmc_rhat_data() %>%
select(one_of("parameter"),
select(all_of("parameter"),
rhat_value = "value",
rhat_rating = "rating",
rhat_description = "description") %>%
Expand All @@ -663,7 +663,7 @@ mcmc_intervals_data <- function(x,
# Don't import `filter`: otherwise, you get a warning when using
# `devtools::load_all(".")` because stats also has a `filter` function

#' @importFrom dplyr inner_join one_of top_n
#' @importFrom dplyr inner_join all_of slice_min
#' @rdname MCMC-intervals
#' @export
mcmc_areas_data <- function(x,
Expand Down Expand Up @@ -736,14 +736,14 @@ mcmc_areas_data <- function(x,

# Find the density values closest to the point estimate
point_ests <- intervals %>%
select(one_of("parameter", "m"))
select(all_of(c("parameter", "m")))

point_centers <- data_inner %>%
inner_join(point_ests, by = "parameter") %>%
group_by(.data$parameter) %>%
mutate(diff = abs(.data$m - .data$x)) %>%
dplyr::top_n(1, -.data$diff) %>%
select(one_of("parameter", "x", "m")) %>%
dplyr::slice_min(order_by = .data$diff, n = 1) %>%
select(all_of(c("parameter", "x", "m"))) %>%
rename(center = "x") %>%
ungroup()

Expand All @@ -765,15 +765,15 @@ mcmc_areas_data <- function(x,
}

data <- dplyr::bind_rows(data_inner, data_outer, points) %>%
select(one_of("parameter", "interval", "interval_width",
"x", "density", "scaled_density")) %>%
select(all_of(c("parameter", "interval", "interval_width",
"x", "density", "scaled_density"))) %>%
# Density scaled so the highest in entire dataframe has height 1
mutate(plotting_density = .data$density / max(.data$density))

if (rlang::has_name(intervals, "rhat_value")) {
rhat_info <- intervals %>%
select(one_of("parameter", "rhat_value",
"rhat_rating", "rhat_description"))
select(all_of(c("parameter", "rhat_value",
"rhat_rating", "rhat_description")))
data <- inner_join(data, rhat_info, by = "parameter")
}
data
Expand Down Expand Up @@ -824,18 +824,15 @@ compute_column_density <- function(df, group_vars, value_var, ...) {
syms()

# Tuck away the subgroups to compute densities on into nested dataframes
sub_df <- dplyr::select(df, !!! group_cols, !! value_var)

group_df <- df %>%
dplyr::select(!!! group_cols, !! value_var) %>%
group_by(!!! group_cols)

by_group <- group_df %>%
split(dplyr::group_indices(group_df)) %>%
dplyr::group_split() %>%
lapply(pull, !! value_var)

nested <- df %>%
dplyr::distinct(!!! group_cols) %>%
nested <- dplyr::group_keys(group_df) %>%
mutate(data = by_group)

nested$density <- lapply(nested$data, compute_interval_density, ...)
Expand Down
32 changes: 32 additions & 0 deletions tests/testthat/test-mcmc-distributions.R
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,38 @@ test_that("mcmc_dens_chains returns a ggplot object", {
expect_gg(p2)
})

test_that("mcmc_dens_chains_data computes densities per parameter-chain group", {
# Regression test for compute_column_density().
# This path groups by both parameter and chain, so it exercises the
# group_split() + group_keys() replacement introduced in PR #448.
# The goal is to verify that densities are still computed for the
# correct parameter-chain groups, in the correct grouping structure.
dens_data <- mcmc_dens_chains_data(arr, n_dens = 100)
by_group <- split(
dens_data,
interaction(dens_data$parameter, dens_data$chain, drop = TRUE, lex.order = TRUE)
)

raw <- melt_mcmc(prepare_mcmc_array(arr))
raw_by_group <- split(
raw,
interaction(raw$Parameter, raw$Chain, drop = TRUE, lex.order = TRUE)
)

manual_density <- function(df) {
dens <- density(df$Value, from = min(df$Value), to = max(df$Value), n = 100)
data.frame(x = dens$x, density = dens$y)
}

expected <- lapply(raw_by_group, manual_density)
expect_setequal(names(by_group), names(expected))
for (nm in names(expected)) {
expect_equal(by_group[[nm]]$x, expected[[nm]]$x)
expect_equal(by_group[[nm]]$density, expected[[nm]]$density, tolerance = 1e-10)
}
})


test_that("mcmc_dens_chains/mcmc_dens_overlay color chains", {
p1 <- mcmc_dens_chains(arr, pars = "beta[1]", regex_pars = "x\\:",
color_chains = FALSE)
Expand Down
Loading