Skip to content

Commit 47faa3d

Browse files
committed
Updated unit tests and R package
1 parent 19b8645 commit 47faa3d

File tree

9 files changed

+148
-20
lines changed

9 files changed

+148
-20
lines changed

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ export(createRandomEffectSamples)
3838
export(createRandomEffectsDataset)
3939
export(createRandomEffectsModel)
4040
export(createRandomEffectsTracker)
41+
export(expand_dims_1d)
42+
export(expand_dims_2d)
43+
export(expand_dims_2d_diag)
4144
export(getRandomEffectSamples)
4245
export(loadForestContainerCombinedJson)
4346
export(loadForestContainerCombinedJsonString)

R/utils.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -892,12 +892,12 @@ expand_dims_1d <- function(input, output_size) {
892892
#' @export
893893
expand_dims_2d <- function(input, output_rows, output_cols) {
894894
if (length(input) == 1) {
895-
output <- as.matrix(rep(input, output_rows * output_cols), ncol = output_cols)
895+
output <- matrix(rep(input, output_rows * output_cols), ncol = output_cols)
896896
} else if (is.numeric(input)) {
897897
if (length(input) == output_cols) {
898-
output <- matrix(rep(x, output_rows), nrow=output_rows, byrow = T)
898+
output <- matrix(rep(input, output_rows), nrow=output_rows, byrow = T)
899899
} else if (length(input) == output_rows) {
900-
output <- matrix(rep(x, output_cols), ncol=output_cols, byrow = F)
900+
output <- matrix(rep(input, output_cols), ncol=output_cols, byrow = F)
901901
} else {
902902
stop("If `input` is a vector, it must either contain `output_rows` or `output_cols` elements")
903903
}

man/bart.Rd

Lines changed: 6 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/expand_dims_1d.Rd

Lines changed: 21 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/expand_dims_2d.Rd

Lines changed: 37 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/expand_dims_2d_diag.Rd

Lines changed: 21 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

stochtree/utils.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,7 @@ def _expand_dims_1d(input: Union[int, float, np.array], output_size: int) -> np.
192192
"""
193193
Convert scalar input to 1D numpy array of dimension `output_size`,
194194
or check that input array is equivalent to a 1D array of dimension `output_size`.
195+
Single element numpy arrays (i.e. `np.array([2.5])`) are treated as scalars.
195196
196197
Parameters
197198
----------
@@ -207,11 +208,14 @@ def _expand_dims_1d(input: Union[int, float, np.array], output_size: int) -> np.
207208
"""
208209
if isinstance(input, np.ndarray):
209210
input = np.squeeze(input)
210-
if input.ndim != 1:
211-
raise ValueError("`input` must be convertible to a 1D numpy array")
212-
if input.shape[0] != output_size:
213-
raise ValueError("`input` must be a 1D numpy array with `output_size` elements")
214-
output = input
211+
if input.ndim > 1:
212+
raise ValueError("`input` must be convertible to a 1D numpy array or scalar")
213+
if input.ndim == 0:
214+
output = np.repeat(input, output_size)
215+
else:
216+
if input.shape[0] != output_size:
217+
raise ValueError("`input` must be a 1D numpy array with `output_size` elements")
218+
output = input
215219
elif isinstance(input, (int, float)):
216220
output = np.repeat(input, output_size)
217221
else:
@@ -227,7 +231,7 @@ def _expand_dims_2d(input: Union[int, float, np.array], output_rows: int, output
227231
2. `input` is a 1D array of length `output_rows`: output is a (`output_rows`, `output_cols`) array with `input` broadcast across each of `output_cols` columns
228232
3. `input` is a 1D array of length `output_cols`: output is a (`output_rows`, `output_cols`) array with `input` broadcast across each of `output_rows` rows
229233
4. `input` is a 2D array of dimension (`output_rows`, `output_cols`): input is passed through as-is
230-
All other cases raise a `ValueError`.
234+
All other cases raise a `ValueError`. Single element numpy arrays (i.e. `np.array([2.5])`) are treated as scalars.
231235
232236
Parameters
233237
----------
@@ -273,6 +277,7 @@ def _expand_dims_2d_diag(input: Union[int, float, np.array], output_size: int) -
273277
"""
274278
Convert scalar input to 2D square numpy array of dimension `output_size` x `output_size` with `input` along the diagonal,
275279
or check that input array is equivalent to a 2D square array of dimension `output_size` x `output_size`.
280+
Single element numpy arrays (i.e. `np.array([2.5])`) are treated as scalars.
276281
277282
Parameters
278283
----------
@@ -288,13 +293,19 @@ def _expand_dims_2d_diag(input: Union[int, float, np.array], output_size: int) -
288293
"""
289294
if isinstance(input, np.ndarray):
290295
input = np.squeeze(input)
291-
if input.ndim != 2:
292-
raise ValueError("`input` must be convertible to a 2D numpy array")
293-
if input.shape[0] != input.shape[1]:
294-
raise ValueError("`input` must be a 2D square numpy array")
295-
if input.shape[0] != output_size:
296-
raise ValueError("`input` must be a 2D square numpy array with exactly `output_size` rows and columns")
297-
output = input
296+
if (input.ndim != 2) and (input.ndim != 0):
297+
raise ValueError("`input` must be convertible to a 2D numpy array or scalar")
298+
if input.ndim == 0:
299+
output = np.zeros(
300+
(output_size, output_size), dtype=float
301+
)
302+
np.fill_diagonal(output, input)
303+
else:
304+
if input.shape[0] != input.shape[1]:
305+
raise ValueError("`input` must be a 2D square numpy array")
306+
if input.shape[0] != output_size:
307+
raise ValueError("`input` must be a 2D square numpy array with exactly `output_size` rows and columns")
308+
output = input
298309
elif isinstance(input, (int, float)):
299310
output = np.zeros(
300311
(output_size, output_size), dtype=float

test/R/testthat/test-utils.R

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,45 @@ test_that("Array conversion", {
2626
)
2727
array_square_3 <- matrix(
2828
c(2.5,0.0,0.0,0.0,2.9,0.0,0.0,0.0,5.6),
29-
nrow = 2, ncol = 2, byrow = T
29+
nrow = 3, ncol = 3, byrow = T
3030
)
3131

3232
# Error cases
3333
expect_error(expand_dims_1d(array_1d_1, 5))
3434
expect_error(expand_dims_1d(array_1d_2, 4))
35-
expect_error(expand_dims_1d(array_1d_3, 3))
3635
expect_error(expand_dims_2d(array_2d_1, 2, 4))
3736
expect_error(expand_dims_2d(array_2d_2, 3, 4))
3837
expect_error(expand_dims_2d_diag(array_square_1, 4))
3938
expect_error(expand_dims_2d_diag(array_square_2, 3))
4039
expect_error(expand_dims_2d_diag(array_square_3, 2))
4140

42-
# # Assertion
43-
# expect_equal(y_hat_orig, y_hat_reloaded)
41+
# Working cases
42+
expect_equal(c(scalar_1,scalar_1,scalar_1), expand_dims_1d(scalar_1, 3))
43+
expect_equal(c(scalar_2,scalar_2,scalar_2,scalar_2), expand_dims_1d(scalar_2, 4))
44+
expect_equal(c(scalar_3,scalar_3), expand_dims_1d(scalar_3, 2))
45+
expect_equal(c(array_1d_3,array_1d_3,array_1d_3), expand_dims_1d(array_1d_3, 3))
46+
47+
output_exp <- matrix(rep(scalar_1, 6), nrow = 2, byrow = T)
48+
expect_equal(output_exp, expand_dims_2d(scalar_1, 2, 3))
49+
output_exp <- matrix(rep(scalar_2, 8), nrow = 2, byrow = T)
50+
expect_equal(output_exp, expand_dims_2d(scalar_2, 2, 4))
51+
output_exp <- matrix(rep(scalar_3, 6), nrow = 3, byrow = T)
52+
expect_equal(output_exp, expand_dims_2d(scalar_3, 3, 2))
53+
output_exp <- matrix(rep(array_1d_3, 6), nrow = 3, byrow = T)
54+
expect_equal(output_exp, expand_dims_2d(array_1d_3, 3, 2))
55+
output_exp <- unname(rbind(array_1d_1, array_1d_1))
56+
expect_equal(output_exp, expand_dims_2d(array_1d_1, 2, 4))
57+
output_exp <- unname(rbind(array_1d_2, array_1d_2, array_1d_2))
58+
expect_equal(output_exp, expand_dims_2d(array_1d_2, 3, 3))
59+
output_exp <- unname(cbind(array_1d_2, array_1d_2, array_1d_2, array_1d_2))
60+
expect_equal(output_exp, expand_dims_2d(array_1d_2, 3, 4))
61+
output_exp <- unname(cbind(array_1d_3, array_1d_3, array_1d_3, array_1d_3))
62+
expect_equal(output_exp, expand_dims_2d(array_1d_3, 1, 4))
63+
output_exp <- unname(rbind(array_1d_3, array_1d_3, array_1d_3, array_1d_3))
64+
expect_equal(output_exp, expand_dims_2d(array_1d_3, 4, 1))
65+
66+
expect_equal(diag(scalar_1, 3), expand_dims_2d_diag(scalar_1, 3))
67+
expect_equal(diag(scalar_2, 2), expand_dims_2d_diag(scalar_2, 2))
68+
expect_equal(diag(scalar_3, 4), expand_dims_2d_diag(scalar_3, 4))
69+
expect_equal(diag(array_1d_3, 2), expand_dims_2d_diag(array_1d_3, 2))
4470
})

test/python/test_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,12 @@ def test_array_conversion(self):
143143
np.testing.assert_array_equal(np.array([scalar_1,scalar_1,scalar_1]), _expand_dims_1d(scalar_1, 3))
144144
np.testing.assert_array_equal(np.array([scalar_2,scalar_2,scalar_2,scalar_2]), _expand_dims_1d(scalar_2, 4))
145145
np.testing.assert_array_equal(np.array([scalar_3,scalar_3]), _expand_dims_1d(scalar_3, 2))
146+
np.testing.assert_array_equal(np.array([array_1d_3[0],array_1d_3[0],array_1d_3[0]]), _expand_dims_1d(array_1d_3, 3))
146147

147148
np.testing.assert_array_equal(np.array([[scalar_1,scalar_1,scalar_1],[scalar_1,scalar_1,scalar_1]]), _expand_dims_2d(scalar_1, 2, 3))
148149
np.testing.assert_array_equal(np.array([[scalar_2,scalar_2,scalar_2,scalar_2],[scalar_2,scalar_2,scalar_2,scalar_2]]), _expand_dims_2d(scalar_2, 2, 4))
149150
np.testing.assert_array_equal(np.array([[scalar_3,scalar_3],[scalar_3,scalar_3],[scalar_3,scalar_3]]), _expand_dims_2d(scalar_3, 3, 2))
151+
np.testing.assert_array_equal(np.array([[array_1d_3[0],array_1d_3[0]],[array_1d_3[0],array_1d_3[0]],[array_1d_3[0],array_1d_3[0]]]), _expand_dims_2d(array_1d_3, 3, 2))
150152
np.testing.assert_array_equal(np.vstack((array_1d_1, array_1d_1)), _expand_dims_2d(array_1d_1, 2, 4))
151153
np.testing.assert_array_equal(np.vstack((array_1d_2, array_1d_2, array_1d_2)), _expand_dims_2d(array_1d_2, 3, 3))
152154
np.testing.assert_array_equal(np.column_stack((array_1d_2, array_1d_2, array_1d_2, array_1d_2)), _expand_dims_2d(array_1d_2, 3, 4))
@@ -156,3 +158,4 @@ def test_array_conversion(self):
156158
np.testing.assert_array_equal(np.array([[scalar_1,0.0,0.0],[0.0,scalar_1,0.0],[0.0,0.0,scalar_1]]), _expand_dims_2d_diag(scalar_1, 3))
157159
np.testing.assert_array_equal(np.array([[scalar_2,0.0],[0.0,scalar_2]]), _expand_dims_2d_diag(scalar_2, 2))
158160
np.testing.assert_array_equal(np.array([[scalar_3,0.0,0.0,0.0],[0.0,scalar_3,0.0,0.0],[0.0,0.0,scalar_3,0.0],[0.0,0.0,0.0,scalar_3]]), _expand_dims_2d_diag(scalar_3, 4))
161+
np.testing.assert_array_equal(np.array([[array_1d_3[0],0.0],[0.0,array_1d_3[0]]]), _expand_dims_2d_diag(array_1d_3, 2))

0 commit comments

Comments
 (0)