Skip to content

Replace full sort in ps_tail() tail selection with C routine#453

Draft
VisruthSK wants to merge 2 commits into
masterfrom
tail-in-c
Draft

Replace full sort in ps_tail() tail selection with C routine#453
VisruthSK wants to merge 2 commits into
masterfrom
tail-in-c

Conversation

@VisruthSK

Copy link
Copy Markdown
Member

This PR was made with assistance from Codex.

Summary

Avoid full sorting in ps_tail() tail selection by selecting only the required Pareto tail--this is rewritten in C (with Codex) and reduce memory pressure but doesn't seem to meaningfully improve wall time significantly. Note that since this is a C routine not C++, no dependencies are added.

Benchmark

Settings:

  • smooth_draws = FALSE
  • tail = "right"
  • ndraws_tail = ceiling(pmin(0.2 * S, 3 * sqrt(S)))
  • 5 seeds: 0, 42, 451, 1984, 4444
size S ndraws_tail speedup median speedup range memory reduction
1x1000 1,000 95 1.28x 0.946–1.72x 1.95x
2x1000 2,000 135 1.33x 1.12–1.74x 2.32x
4x1000 4,000 190 1.20x 1.06–2.09x 2.84x
4x2000 8,000 269 1.48x 1.14–2.19x 3.49x
4x4000 16,000 380 1.46x 1.43–1.80x 4.38x
4x10000 40,000 600 1.50x 1.13–1.69x 5.95x
stress_4x25000 100,000 949 2.20x 1.77–2.80x 8.16x
stress_4x250000 1,000,000 3,000 3.62x 2.86–4.22x 17.7x
Benchmark code

This was run with posterior 1.7.1 compared against this branch's version of posterior.

library(tidyverse)
library(bench)
library(posterior)

x_base <- posterior::example_draws("eight_schools") |>
  posterior::as_draws_matrix() |>
  (\(x) as.numeric(x[, "tau"]))()

original_ps_tail <- posterior::ps_tail

devtools::load_all(reset = TRUE, recompile = TRUE)

sizes <- tibble::tibble(
  size_label = c(
    "1x1000",
    "2x1000",
    "4x1000",
    "4x2000",
    "4x4000",
    "4x10000",
    "stress_4x25000",
    "stress_4x250000"
  ),
  chains = c(1L, 2L, 4L, 4L, 4L, 4L, 4L, 4L),
  draws_per_chain = c(
    1000L,
    1000L,
    1000L,
    2000L,
    4000L,
    10000L,
    25000L,
    250000L
  )
) |>
  dplyr::mutate(
    S = chains * draws_per_chain,
    ndraws_tail = as.integer(ceiling(pmin(0.2 * S, 3 * sqrt(S)))),
    tail_fraction = ndraws_tail / S,
    iterations = dplyr::if_else(S <= 4000L, 1000L, 300L)
  )

warmup_one <- function(x, ndraws_tail) {
  invisible(
    original_ps_tail(
      x,
      ndraws_tail = ndraws_tail,
      smooth_draws = FALSE,
      tail = "right"
    )
  )

  invisible(
    ps_tail(
      x,
      ndraws_tail = ndraws_tail,
      smooth_draws = FALSE,
      tail = "right"
    )
  )
}

bench_one <- function(x, ndraws_tail, iterations) {
  bench::mark(
    original = original_ps_tail(
      x,
      ndraws_tail = ndraws_tail,
      smooth_draws = FALSE,
      tail = "right"
    ),
    new = ps_tail(
      x,
      ndraws_tail = ndraws_tail,
      smooth_draws = FALSE,
      tail = "right"
    ),
    iterations = iterations,
    check = TRUE
  )
}

run_ps_tail_benchmark <- function(seed) {
  set.seed(seed)

  bench_sizes <- sizes |>
    dplyr::mutate(
      x = purrr::map(S, \(S) sample(x_base, size = S, replace = TRUE))
    )

  purrr::pwalk(
    list(bench_sizes$x, bench_sizes$ndraws_tail),
    warmup_one
  )

  results <- purrr::pmap(
    list(bench_sizes$x, bench_sizes$ndraws_tail, bench_sizes$iterations),
    \(x, ndraws_tail, iterations) {
      bench_one(
        x = x,
        ndraws_tail = ndraws_tail,
        iterations = iterations
      )
    }
  )

  names(results) <- bench_sizes$size_label

  absolute <- purrr::imap_dfr(
    results,
    \(result, size_label) {
      summary(result, relative = FALSE) |>
        dplyr::filter(as.character(expression) != "gc") |>
        dplyr::mutate(size_label = size_label, .before = expression)
    }
  ) |>
    dplyr::left_join(
      bench_sizes |>
        dplyr::select(
          size_label,
          chains,
          draws_per_chain,
          S,
          ndraws_tail,
          tail_fraction,
          iterations
        ),
      by = "size_label"
    ) |>
    dplyr::mutate(seed = seed, .before = size_label) |>
    dplyr::relocate(
      chains,
      draws_per_chain,
      S,
      ndraws_tail,
      tail_fraction,
      iterations,
      .after = size_label
    ) |>
    dplyr::select(-c(result, memory, time, gc))

  speedup <- absolute |>
    dplyr::select(
      seed,
      size_label,
      chains,
      draws_per_chain,
      S,
      ndraws_tail,
      tail_fraction,
      iterations,
      expression,
      median,
      `itr/sec`,
      mem_alloc
    ) |>
    tidyr::pivot_wider(
      names_from = expression,
      values_from = c(median, `itr/sec`, mem_alloc)
    ) |>
    dplyr::mutate(
      speedup = as.numeric(median_original) / as.numeric(median_new),
      itr_sec_ratio = `itr/sec_new` / `itr/sec_original`,
      memory_ratio = as.numeric(mem_alloc_original) / as.numeric(mem_alloc_new)
    )

  list(
    seed = seed,
    absolute = absolute,
    speedup = speedup
  )
}

speedup_by_seed <- purrr::map(
  c(0, 42, 451, 1984, 4444),
  run_ps_tail_benchmark
) |>
  purrr::map_dfr("speedup")

speedup_summary <- speedup_by_seed |>
  dplyr::group_by(
    size_label,
    chains,
    draws_per_chain,
    S,
    ndraws_tail,
    tail_fraction,
    iterations
  ) |>
  dplyr::summarise(
    n_seeds = dplyr::n_distinct(seed),
    speedup_min = min(speedup),
    speedup_median = median(speedup),
    speedup_max = max(speedup),
    itr_sec_ratio_min = min(itr_sec_ratio),
    itr_sec_ratio_median = median(itr_sec_ratio),
    itr_sec_ratio_max = max(itr_sec_ratio),
    memory_ratio = dplyr::first(memory_ratio),
    .groups = "drop"
  ) |>
  dplyr::arrange(S)

speedup_by_seed |> print(n = Inf)
#> # A tibble: 40 × 17
#>     seed size_label      chains draws_per_chain       S ndraws_tail tail_fraction iterations median_original median_new `itr/sec_original` `itr/sec_new` mem_alloc_original mem_alloc_new speedup itr_sec_ratio memory_ratio
#>    <dbl> <chr>            <int>           <int>   <int>       <int>         <dbl>      <int>        <bch:tm>   <bch:tm>              <dbl>         <dbl>          <bch:byt>     <bch:byt>   <dbl>         <dbl>        <dbl>
#>  1     0 1x1000               1            1000    1000          95       0.095         1000         391.6µs    414.1µs             2683.         2304.             69.59KB       35.65KB   0.946         0.859         1.95
#>  2     0 2x1000               2            1000    2000         135       0.0675        1000         531.5µs   474.55µs             1879.         2217.            120.23KB       51.77KB   1.12          1.18          2.32
#>  3     0 4x1000               4            1000    4000         190       0.0475        1000         751.9µs   360.45µs             1298.         2228.            212.73KB       74.81KB   2.09          1.72          2.84
#>  4     0 4x2000               4            2000    8000         269       0.0336         300         819.3µs    552.5µs              993.         1417.            388.52KB       111.2KB   1.48          1.43          3.49
#>  5     0 4x4000               4            4000   16000         380       0.0238         300          1.16ms    813.4µs              731.          983.            721.75KB      164.92KB   1.43          1.34          4.38
#>  6     0 4x10000              4           10000   40000         600       0.015          300          2.82ms     1.66ms              311.          489.              1.64MB      282.28KB   1.69          1.57          5.95
#>  7     0 stress_4x25000       4           25000  100000         949       0.00949        300          9.55ms     5.39ms              102.          188.               3.9MB      489.31KB   1.77          1.84          8.16
#>  8     0 stress_4x250000      4          250000 1000000        3000       0.003          300          82.7ms    19.59ms               11.7          47.2            36.34MB        2.05MB   4.22          4.03         17.7 
#>  9    42 1x1000               1            1000    1000          95       0.095         1000         225.6µs      176µs             3387.         4312.             69.59KB       35.65KB   1.28          1.27          1.95
#> 10    42 2x1000               2            1000    2000         135       0.0675        1000         334.6µs      234µs             2411.         3494.            120.23KB       51.77KB   1.43          1.45          2.32
#> 11    42 4x1000               4            1000    4000         190       0.0475        1000         471.4µs    443.8µs             1694.         1888.            212.73KB       74.81KB   1.06          1.11          2.84
#> 12    42 4x2000               4            2000    8000         269       0.0336         300        856.75µs    599.7µs              876.         1461.            388.52KB       111.2KB   1.43          1.67          3.49
#> 13    42 4x4000               4            4000   16000         380       0.0238         300          1.06ms    725.9µs              891.         1264.            721.75KB      164.92KB   1.46          1.42          4.38
#> 14    42 4x10000              4           10000   40000         600       0.015          300           1.7ms      1.2ms              568.          788.              1.64MB      282.28KB   1.41          1.39          5.95
#> 15    42 stress_4x25000       4           25000  100000         949       0.00949        300          7.65ms     2.73ms              120.          313.               3.9MB      489.31KB   2.80          2.60          8.16
#> 16    42 stress_4x250000      4          250000 1000000        3000       0.003          300         66.62ms    23.28ms               14.4          43.0            36.34MB        2.05MB   2.86          2.99         17.7 
#> 17   451 1x1000               1            1000    1000          95       0.095         1000        339.55µs    282.4µs             2919.         3504.             69.59KB       35.65KB   1.20          1.20          1.95
#> 18   451 2x1000               2            1000    2000         135       0.0675        1000        526.95µs    396.7µs             1947.         2498.            120.23KB       51.77KB   1.33          1.28          2.32
#> 19   451 4x1000               4            1000    4000         190       0.0475        1000         726.5µs    607.6µs             1311.         1680.            212.73KB       74.81KB   1.20          1.28          2.84
#> 20   451 4x2000               4            2000    8000         269       0.0336         300          1.42ms    840.6µs              695.         1156.            388.52KB       111.2KB   1.69          1.66          3.49
#> 21   451 4x4000               4            4000   16000         380       0.0238         300           2.1ms     1.17ms              467.          818.            721.75KB      164.92KB   1.80          1.75          4.38
#> 22   451 4x10000              4           10000   40000         600       0.015          300          2.78ms     1.78ms              333.          475.              1.64MB      282.28KB   1.56          1.43          5.95
#> 23   451 stress_4x25000       4           25000  100000         949       0.00949        300          7.35ms      3.1ms              125.          283.               3.9MB      489.31KB   2.38          2.26          8.16
#> 24   451 stress_4x250000      4          250000 1000000        3000       0.003          300         82.38ms    23.24ms               12.2          42.7            36.34MB        2.05MB   3.54          3.50         17.7 
#> 25  1984 1x1000               1            1000    1000          95       0.095         1000         206.9µs    160.5µs             3858.         5007.             69.59KB       35.65KB   1.29          1.30          1.95
#> 26  1984 2x1000               2            1000    2000         135       0.0675        1000         399.3µs    229.7µs             2246.         3302.            120.23KB       51.77KB   1.74          1.47          2.32
#> 27  1984 4x1000               4            1000    4000         190       0.0475        1000         458.9µs   392.65µs             1728.         2012.            212.73KB       74.81KB   1.17          1.16          2.84
#> 28  1984 4x2000               4            2000    8000         269       0.0336         300          1.27ms    579.9µs              814.         1361.            388.52KB       111.2KB   2.19          1.67          3.49
#> 29  1984 4x4000               4            4000   16000         380       0.0238         300          1.33ms    918.5µs              659.          917.            721.75KB      164.92KB   1.44          1.39          4.38
#> 30  1984 4x10000              4           10000   40000         600       0.015          300          2.76ms     2.45ms              322.          386.              1.64MB      282.28KB   1.13          1.20          5.95
#> 31  1984 stress_4x25000       4           25000  100000         949       0.00949        300          9.38ms     4.27ms              104.          232.               3.9MB      489.31KB   2.20          2.22          8.16
#> 32  1984 stress_4x250000      4          250000 1000000        3000       0.003          300         86.39ms    22.96ms               11.6          44.1            36.34MB        2.05MB   3.76          3.82         17.7 
#> 33  4444 1x1000               1            1000    1000          95       0.095         1000         288.4µs    167.8µs             3044.         5338.             69.59KB       35.65KB   1.72          1.75          1.95
#> 34  4444 2x1000               2            1000    2000         135       0.0675        1000        224.65µs    195.5µs             3829.         3958.            120.23KB       51.77KB   1.15          1.03          2.32
#> 35  4444 4x1000               4            1000    4000         190       0.0475        1000         795.1µs    551.8µs             1262.         1729.            212.73KB       74.81KB   1.44          1.37          2.84
#> 36  4444 4x2000               4            2000    8000         269       0.0336         300          1.16ms     1.02ms              894.          964.            388.52KB       111.2KB   1.14          1.08          3.49
#> 37  4444 4x4000               4            4000   16000         380       0.0238         300          1.64ms    939.5µs              624.          920.            721.75KB      164.92KB   1.74          1.47          4.38
#> 38  4444 4x10000              4           10000   40000         600       0.015          300          2.79ms     1.86ms              348.          456.              1.64MB      282.28KB   1.50          1.31          5.95
#> 39  4444 stress_4x25000       4           25000  100000         949       0.00949        300          8.28ms     4.61ms              121.          217.               3.9MB      489.31KB   1.80          1.80          8.16
#> 40  4444 stress_4x250000      4          250000 1000000        3000       0.003          300         87.32ms     24.1ms               11.5          41.2            36.34MB        2.05MB   3.62          3.59         17.7 

speedup_summary |> print(n = Inf)
#> # A tibble: 8 × 15
#>   size_label      chains draws_per_chain       S ndraws_tail tail_fraction iterations n_seeds speedup_min speedup_median speedup_max itr_sec_ratio_min itr_sec_ratio_median itr_sec_ratio_max memory_ratio
#>   <chr>            <int>           <int>   <int>       <int>         <dbl>      <int>   <int>       <dbl>          <dbl>       <dbl>             <dbl>                <dbl>             <dbl>        <dbl>
#> 1 1x1000               1            1000    1000          95       0.095         1000       5       0.946           1.28        1.72             0.859                 1.27              1.75         1.95
#> 2 2x1000               2            1000    2000         135       0.0675        1000       5       1.12            1.33        1.74             1.03                  1.28              1.47         2.32
#> 3 4x1000               4            1000    4000         190       0.0475        1000       5       1.06            1.20        2.09             1.11                  1.28              1.72         2.84
#> 4 4x2000               4            2000    8000         269       0.0336         300       5       1.14            1.48        2.19             1.08                  1.66              1.67         3.49
#> 5 4x4000               4            4000   16000         380       0.0238         300       5       1.43            1.46        1.80             1.34                  1.42              1.75         4.38
#> 6 4x10000              4           10000   40000         600       0.015          300       5       1.13            1.50        1.69             1.20                  1.39              1.57         5.95
#> 7 stress_4x25000       4           25000  100000         949       0.00949        300       5       1.77            2.20        2.80             1.80                  2.22              2.60         8.16
#> 8 stress_4x250000      4          250000 1000000        3000       0.003          300       5       2.86            3.62        4.22             2.99                  3.59              4.03        17.7

Copyright and Licensing

By submitting this pull request, the copyright holder is agreeing to license the submitted work under the following licenses:

@github-actions

Copy link
Copy Markdown

This is how benchmark results would change (along with a 95% confidence interval in relative change) if fad54c2 is merged into master:

  • ✔️as_draws_array: 140ms -> 140ms [-0.37%, +0.76%]
  • ✔️as_draws_df: 69.5ms -> 69.6ms [-0.94%, +1.29%]
  • ✔️as_draws_list: 158ms -> 157ms [-1.51%, +0.5%]
  • ✔️as_draws_matrix: 24.2ms -> 24.2ms [-0.46%, +0.85%]
  • ✔️as_draws_rvars: 112ms -> 112ms [-0.95%, +0.62%]
  • ✔️summarise_draws_100_variables: 723ms -> 724ms [-0.1%, +0.18%]
  • ✔️summarise_draws_10_variables: 81.3ms -> 81.2ms [-0.48%, +0.16%]
    Further explanation regarding interpretation and methodology can be found in the documentation.

@github-actions

Copy link
Copy Markdown

This is how benchmark results would change (along with a 95% confidence interval in relative change) if 0e755bc is merged into master:

  • ✔️as_draws_array: 189ms -> 193ms [-7.14%, +10.63%]
  • ✔️as_draws_df: 124ms -> 122ms [-12.81%, +10.17%]
  • ✔️as_draws_list: 226ms -> 231ms [-5.62%, +9.9%]
  • ✔️as_draws_matrix: 33ms -> 33ms [-3.38%, +3.64%]
  • ✔️as_draws_rvars: 147ms -> 148ms [-6.75%, +8.29%]
  • ✔️summarise_draws_100_variables: 759ms -> 760ms [-1.6%, +1.91%]
  • ✔️summarise_draws_10_variables: 84.9ms -> 84.5ms [-3.3%, +2.34%]
    Further explanation regarding interpretation and methodology can be found in the documentation.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant