Skip to content

Commit 32f13a6

Browse files
committed
Reformat other python source files
1 parent 78ceee2 commit 32f13a6

File tree

9 files changed

+404
-225
lines changed

9 files changed

+404
-225
lines changed

stochtree/bcf.py

Lines changed: 137 additions & 90 deletions
Large diffs are not rendered by default.

stochtree/config.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,9 @@ def __init__(
119119
raise ValueError("`leaf_dimension` must be an integer greater than 0")
120120
if leaf_model_scale is None:
121121
diag_value = 1.0 / num_trees
122-
leaf_model_scale_array = np.zeros((leaf_dimension, leaf_dimension), dtype=float)
122+
leaf_model_scale_array = np.zeros(
123+
(leaf_dimension, leaf_dimension), dtype=float
124+
)
123125
np.fill_diagonal(leaf_model_scale_array, diag_value)
124126
else:
125127
if isinstance(leaf_model_scale, np.ndarray):
@@ -432,7 +434,7 @@ def get_feature_types(self) -> np.ndarray:
432434
"""
433435
return self.feature_types
434436

435-
def get_sweep_update_indices(self) -> Union[np.ndarray,None]:
437+
def get_sweep_update_indices(self) -> Union[np.ndarray, None]:
436438
"""
437439
Query vector of (0-indexed) indices of trees to update in a sweep
438440

stochtree/data.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ def update_basis(self, basis: np.array):
5959
Numpy array of basis vectors.
6060
"""
6161
if not self.has_basis():
62-
raise ValueError("This dataset does not have a basis to update. Please use `add_basis` to create and initialize the values in the Dataset's basis matrix.")
62+
raise ValueError(
63+
"This dataset does not have a basis to update. Please use `add_basis` to create and initialize the values in the Dataset's basis matrix."
64+
)
6365
if not isinstance(basis, np.ndarray):
6466
raise ValueError("basis must be a numpy array.")
6567
if np.ndim(basis) == 1:
@@ -71,9 +73,13 @@ def update_basis(self, basis: np.array):
7173
n, p = basis_.shape
7274
basis_rowmajor = np.ascontiguousarray(basis_)
7375
if self.num_basis() != p:
74-
raise ValueError(f"The number of columns in the new basis ({p}) must match the number of columns in the existing basis ({self.num_basis()}).")
76+
raise ValueError(
77+
f"The number of columns in the new basis ({p}) must match the number of columns in the existing basis ({self.num_basis()})."
78+
)
7579
if self.num_observations() != n:
76-
raise ValueError(f"The number of rows in the new basis ({n}) must match the number of rows in the existing basis ({self.num_observations()}).")
80+
raise ValueError(
81+
f"The number of rows in the new basis ({n}) must match the number of rows in the existing basis ({self.num_observations()})."
82+
)
7783
self.dataset_cpp.UpdateBasis(basis_rowmajor, n, p, True)
7884

7985
def add_variance_weights(self, variance_weights: np.array):
@@ -91,12 +97,14 @@ def add_variance_weights(self, variance_weights: np.array):
9197
n = variance_weights_.size
9298
if variance_weights_.ndim != 1:
9399
raise ValueError("variance_weights must be a 1-dimensional numpy array.")
94-
100+
95101
self.dataset_cpp.AddVarianceWeights(variance_weights_, n)
96-
97-
def update_variance_weights(self, variance_weights: np.array, exponentiate: bool = False):
102+
103+
def update_variance_weights(
104+
self, variance_weights: np.array, exponentiate: bool = False
105+
):
98106
"""
99-
Update variance weights in a dataset. Allows users to build an ensemble that depends on
107+
Update variance weights in a dataset. Allows users to build an ensemble that depends on
100108
variance weights that are updated throughout the sampler.
101109
102110
Parameters
@@ -107,15 +115,19 @@ def update_variance_weights(self, variance_weights: np.array, exponentiate: bool
107115
Whether to exponentiate the variance weights before storing them in the dataset.
108116
"""
109117
if not self.has_variance_weights():
110-
raise ValueError("This dataset does not have variance weights to update. Please use `add_variance_weights` to create and initialize the values in the Dataset's variance weight vector.")
118+
raise ValueError(
119+
"This dataset does not have variance weights to update. Please use `add_variance_weights` to create and initialize the values in the Dataset's variance weight vector."
120+
)
111121
if not isinstance(variance_weights, np.ndarray):
112122
raise ValueError("variance_weights must be a numpy array.")
113123
variance_weights_ = np.squeeze(variance_weights)
114124
n = variance_weights_.size
115125
if variance_weights_.ndim != 1:
116126
raise ValueError("variance_weights must be a 1-dimensional numpy array.")
117127
if self.num_observations() != n:
118-
raise ValueError(f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()}).")
128+
raise ValueError(
129+
f"The number of rows in the new variance_weights vector ({n}) must match the number of rows in the existing vector ({self.num_observations()})."
130+
)
119131
self.dataset_cpp.UpdateVarianceWeights(variance_weights_, n, exponentiate)
120132

121133
def num_observations(self) -> int:
@@ -150,7 +162,7 @@ def num_basis(self) -> int:
150162
Dimension of the basis vector in the dataset, returning 0 if the dataset does not have a basis
151163
"""
152164
return self.dataset_cpp.NumBasis()
153-
165+
154166
def get_covariates(self) -> np.array:
155167
"""
156168
Return the covariates in a Dataset as a numpy array
@@ -161,7 +173,7 @@ def get_covariates(self) -> np.array:
161173
Covariate data
162174
"""
163175
return self.dataset_cpp.GetCovariates()
164-
176+
165177
def get_basis(self) -> np.array:
166178
"""
167179
Return the bases in a Dataset as a numpy array
@@ -172,7 +184,7 @@ def get_basis(self) -> np.array:
172184
Basis data
173185
"""
174186
return self.dataset_cpp.GetBasis()
175-
187+
176188
def get_variance_weights(self) -> np.array:
177189
"""
178190
Return the variance weights in a Dataset as a numpy array

stochtree/forest.py

Lines changed: 67 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -161,13 +161,13 @@ def set_root_leaves(
161161

162162
def collapse(self, batch_size: int) -> None:
163163
"""
164-
Collapse forests in this container by a pre-specified batch size.
165-
For example, if we have a container of twenty 10-tree forests, and we
166-
specify a `batch_size` of 5, then this method will yield four 50-tree
167-
forests. "Excess" forests remaining after the size of a forest container
168-
is divided by `batch_size` will be pruned from the beginning of the
169-
container (i.e. earlier sampled forests will be deleted). This method
170-
has no effect if `batch_size` is larger than the number of forests
164+
Collapse forests in this container by a pre-specified batch size.
165+
For example, if we have a container of twenty 10-tree forests, and we
166+
specify a `batch_size` of 5, then this method will yield four 50-tree
167+
forests. "Excess" forests remaining after the size of a forest container
168+
is divided by `batch_size` will be pruned from the beginning of the
169+
container (i.e. earlier sampled forests will be deleted). This method
170+
has no effect if `batch_size` is larger than the number of forests
171171
in a container.
172172
173173
Parameters
@@ -177,12 +177,23 @@ def collapse(self, batch_size: int) -> None:
177177
"""
178178
container_size = self.num_samples()
179179
if batch_size <= container_size and batch_size > 1:
180-
reverse_container_inds = np.linspace(start=container_size, stop=1, num=container_size, dtype=int)
180+
reverse_container_inds = np.linspace(
181+
start=container_size, stop=1, num=container_size, dtype=int
182+
)
181183
num_clean_batches = container_size // batch_size
182-
batch_inds = (reverse_container_inds - (container_size - ((container_size // num_clean_batches) * num_clean_batches)) - 1) // batch_size
184+
batch_inds = (
185+
reverse_container_inds
186+
- (
187+
container_size
188+
- ((container_size // num_clean_batches) * num_clean_batches)
189+
)
190+
- 1
191+
) // batch_size
183192
batch_inds = batch_inds.astype(int)
184193
for batch_ind in np.flip(np.unique(batch_inds[batch_inds >= 0])):
185-
merge_forest_inds = np.sort(reverse_container_inds[batch_inds == batch_ind] - 1)
194+
merge_forest_inds = np.sort(
195+
reverse_container_inds[batch_inds == batch_ind] - 1
196+
)
186197
num_merge_forests = len(merge_forest_inds)
187198
self.combine_forests(merge_forest_inds)
188199
for i in range(num_merge_forests - 1, 0, -1):
@@ -194,10 +205,8 @@ def collapse(self, batch_size: int) -> None:
194205
num_delete_forests = len(delete_forest_inds)
195206
for i in range(num_delete_forests - 1, -1, -1):
196207
self.delete_sample(delete_forest_inds[i])
197-
198-
def combine_forests(
199-
self, forest_inds: np.array
200-
) -> None:
208+
209+
def combine_forests(self, forest_inds: np.array) -> None:
201210
"""
202211
Collapse specified forests into a single forest
203212
@@ -214,42 +223,56 @@ def combine_forests(
214223
forest_inds_sorted = forest_inds_sorted.astype(int)
215224
self.forest_container_cpp.CombineForests(forest_inds_sorted)
216225

217-
def add_to_forest(
218-
self, forest_index: int, constant_value : float
219-
) -> None:
226+
def add_to_forest(self, forest_index: int, constant_value: float) -> None:
220227
"""
221228
Add a constant value to every leaf of every tree of a given forest
222229
223230
Parameters
224231
----------
225-
forest_index : int
232+
forest_index : int
226233
Index of forest whose leaves will be modified (0-indexed)
227-
constant_value : float
234+
constant_value : float
228235
Value to add to every leaf of every tree of the forest at `forest_index`
229236
"""
230-
if not isinstance(forest_index, int) and not isinstance(constant_value, (int, float)):
231-
raise ValueError("forest_index must be an integer and constant_multiple must be a float or int")
232-
if not forest_index >= 0 or not forest_index < self.forest_container_cpp.NumSamples():
233-
raise ValueError("forest_index must be >= 0 and less than the total number of samples in a forest container")
237+
if not isinstance(forest_index, int) and not isinstance(
238+
constant_value, (int, float)
239+
):
240+
raise ValueError(
241+
"forest_index must be an integer and constant_multiple must be a float or int"
242+
)
243+
if (
244+
not forest_index >= 0
245+
or not forest_index < self.forest_container_cpp.NumSamples()
246+
):
247+
raise ValueError(
248+
"forest_index must be >= 0 and less than the total number of samples in a forest container"
249+
)
234250
self.forest_container_cpp.AddToForest(forest_index, constant_value)
235251

236-
def multiply_forest(
237-
self, forest_index: int, constant_multiple : float
238-
) -> None:
252+
def multiply_forest(self, forest_index: int, constant_multiple: float) -> None:
239253
"""
240254
Multiply every leaf of every tree of a given forest by constant value
241255
242256
Parameters
243257
----------
244-
forest_index : int
258+
forest_index : int
245259
Index of forest whose leaves will be modified (0-indexed)
246-
constant_multiple : float
260+
constant_multiple : float
247261
Value to multiply through by every leaf of every tree of the forest at `forest_index`
248262
"""
249-
if not isinstance(forest_index, int) and not isinstance(constant_multiple, (int, float)):
250-
raise ValueError("forest_index must be an integer and constant_multiple must be a float or int")
251-
if not forest_index >= 0 or not forest_index < self.forest_container_cpp.NumSamples():
252-
raise ValueError("forest_index must be >= 0 and less than the total number of samples in a forest container")
263+
if not isinstance(forest_index, int) and not isinstance(
264+
constant_multiple, (int, float)
265+
):
266+
raise ValueError(
267+
"forest_index must be an integer and constant_multiple must be a float or int"
268+
)
269+
if (
270+
not forest_index >= 0
271+
or not forest_index < self.forest_container_cpp.NumSamples()
272+
):
273+
raise ValueError(
274+
"forest_index must be >= 0 and less than the total number of samples in a forest container"
275+
)
253276
self.forest_container_cpp.MultiplyForest(forest_index, constant_multiple)
254277

255278
def save_to_json_file(self, json_filename: str) -> None:
@@ -1021,7 +1044,7 @@ def set_root_leaves(self, leaf_value: Union[float, np.array]) -> None:
10211044
else:
10221045
self.forest_cpp.SetRootValue(leaf_value)
10231046
self.internal_forest_is_empty = False
1024-
1047+
10251048
def merge_forest(self, other_forest):
10261049
"""
10271050
Create a larger forest by merging the trees of this forest with those of another forest
@@ -1034,13 +1057,19 @@ def merge_forest(self, other_forest):
10341057
if not isinstance(other_forest, Forest):
10351058
raise ValueError("other_forest must be an instance of the Forest class")
10361059
if self.leaf_constant != other_forest.leaf_constant:
1037-
raise ValueError("Forests must have matching leaf dimensions in order to be merged")
1060+
raise ValueError(
1061+
"Forests must have matching leaf dimensions in order to be merged"
1062+
)
10381063
if self.output_dimension != other_forest.output_dimension:
1039-
raise ValueError("Forests must have matching leaf dimensions in order to be merged")
1064+
raise ValueError(
1065+
"Forests must have matching leaf dimensions in order to be merged"
1066+
)
10401067
if self.is_exponentiated != other_forest.is_exponentiated:
1041-
raise ValueError("Forests must have matching leaf dimensions in order to be merged")
1068+
raise ValueError(
1069+
"Forests must have matching leaf dimensions in order to be merged"
1070+
)
10421071
self.forest_cpp.MergeForest(other_forest.forest_cpp)
1043-
1072+
10441073
def add_constant(self, constant_value):
10451074
"""
10461075
Add a constant value to every leaf of every tree in an ensemble. If leaves are multi-dimensional, `constant_value` will be added to every dimension of the leaves.
@@ -1051,7 +1080,7 @@ def add_constant(self, constant_value):
10511080
Value that will be added to every leaf of every tree
10521081
"""
10531082
self.forest_cpp.AddConstant(constant_value)
1054-
1083+
10551084
def multiply_constant(self, constant_multiple):
10561085
"""
10571086
Multiply every leaf of every tree by a constant value. If leaves are multi-dimensional, `constant_multiple` will be multiplied through every dimension of the leaves.

0 commit comments

Comments
 (0)