Skip to content

Commit b0ee0ec

Browse files
authored
feat: Add support for arviz v0.19-v0.22 (#57)
* chore: Remove DimensionalData v0.29.25 compat * chore: Increment patch number for release * chore: Fix version specifiers * test: Only check that constructed ELPDData has at least all keys as arviz one * feat: Add good_k field to constructed ELPDData * feat: Add support for arviz v0.19 * feat: Add support for arviz v0.20 * feat: Add support for arviz v0.21 * feat: Add support for arviz v0.22 * Revert "chore: Fix version specifiers" This reverts commit be3bf86. * Revert "chore: Remove DimensionalData v0.29.25 compat" This reverts commit 2fe154d.
1 parent 068d9ba commit b0ee0ec

File tree

4 files changed

+8
-4
lines changed

4 files changed

+8
-4
lines changed

CondaPkg.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
pandas = ""
33
matplotlib = ""
44
xarray = ""
5-
arviz = ">=0.15.0,<=0.18"
5+
arviz = ">=0.15.0,<=0.22"

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ArviZPythonPlots"
22
uuid = "4a6e88f0-2c8e-11ee-0601-e94153f0eada"
33
authors = ["Seth Axen <seth@sethaxen.com>"]
4-
version = "0.1.12"
4+
version = "0.1.13"
55

66
[deps]
77
CondaPkg = "992eb4ea-22a4-4c89-a5bb-47a3300528ab"

src/conversions.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ function PythonCall.Py(d::PSISLOOResult)
44
psis_result = d.psis_result
55
ds = convert_to_dataset((loo_i=pointwise.elpd, pareto_shape=pointwise.pareto_shape))
66
pyds = PythonCall.Py(ds)
7+
n_samples = d.psis_result.ndraws * d.psis_result.nchains
8+
good_k = min(1 - inv(log10(n_samples)), 0.7)
9+
710
entries = (
811
elpd_loo=estimates.elpd,
912
se=estimates.se_elpd,
@@ -14,6 +17,7 @@ function PythonCall.Py(d::PSISLOOResult)
1417
loo_i=pyds.loo_i,
1518
pareto_k=pyds.pareto_shape,
1619
scale="log",
20+
good_k=good_k,
1721
)
1822
data = pylist(values(entries))
1923
index = pylist(map(pystr, keys(entries)))

test/test_conversions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ using Test
1010
loo_result = loo(idata; reff=1)
1111
loo_py_result = ArviZPythonPlots.arviz.loo(idata; pointwise=true, reff=1)
1212
py_loo_result = Py(loo_result)
13-
@test all(
14-
pyconvert(Array{String}, py_loo_result.keys()) ==
13+
@test issubset(
1514
pyconvert(Array{String}, loo_py_result.keys()),
15+
pyconvert(Array{String}, py_loo_result.keys()),
1616
)
1717
@test pyconvert(Float64, py_loo_result.elpd_loo)
1818
pyconvert(Float64, loo_py_result.elpd_loo) rtol = 1e-3

0 commit comments

Comments
 (0)