
Commit bb7b988

by256 and xuyxu authored
feat: support arbitrary training criteria (#86)
* feat: support arbitrary training criteria
* feat: support arbitrary training criteria (update docs to include step)
* chore: support arbitrary training criteria (clean up to adhere to black code format)
* finish code

Co-authored-by: xuyxu <xuyx@lamda.nju.edu.cn>
1 parent 13f4b84 commit bb7b988

File tree

14 files changed: +265 -89 lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ Changelog
 Ver 0.1.*
 ---------
 
+* |Feature| |API| Support arbitrary training criteria for all ensembles except Gradient Boosting | `@by256 <https://github.com/by256>`__ and `@xuyxu <https://github.com/xuyxu>`__
 * |Fix| Fix missing functionality of ``save_model`` for :meth:`fit` of Soft Gradient Boosting | `@xuyxu <https://github.com/xuyxu>`__
 * |Feature| |API| Add :class:`SoftGradientBoostingClassifier` and :class:`SoftGradientBoostingRegressor` | `@xuyxu <https://github.com/xuyxu>`__
 * |Feature| |API| Support using dataloader with multiple input | `@xuyxu <https://github.com/xuyxu>`__

docs/index.rst

Lines changed: 3 additions & 0 deletions
@@ -34,6 +34,9 @@ Example
        estimator=base_estimator,           # here is your deep learning model
        n_estimators=10,                    # number of base estimators
    )
+   # Set the criterion
+   criterion = nn.CrossEntropyLoss()       # training objective
+   ensemble.set_criterion(criterion)
 
    # Set the optimizer
    ensemble.set_optimizer(

docs/quick_start.rst

Lines changed: 15 additions & 0 deletions
@@ -82,6 +82,17 @@ The meaning of different arguments is listed as follow:
 * ``n_estimators``: The number of base estimators in the ensemble.
 * ``cuda``: Specify whether to use GPU for training and evaluating the ensemble.
 
+Set the Criterion
+-----------------
+
+The next step is to set the objective function. Since our ensemble model is a classifier, we
+will use cross-entropy:
+
+.. code-block:: python
+
+    criterion = nn.CrossEntropyLoss()
+    model.set_criterion(criterion)
+
 Set the Optimizer
 -----------------
 
@@ -180,6 +191,10 @@ The script below shows an example on using VotingClassifier with 10 MLPs for cla
        cuda=True,
    )
 
+   # Set the criterion
+   criterion = nn.CrossEntropyLoss()
+   model.set_criterion(criterion)
+
    # Set the optimizer
    model.set_optimizer('Adam', lr=1e-3, weight_decay=5e-4)
 
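As a hedged illustration of what "arbitrary training criteria" enables, the sketch below swaps the plain cross-entropy from the quick-start example for a class-weighted one. The MLP definition, the weight values, and the loader name are assumptions made up for this example, not part of the commit:

    import torch
    import torch.nn as nn
    from torchensemble import VotingClassifier

    # Hypothetical base network; any nn.Module can serve as the base estimator.
    class MLP(nn.Module):
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)
            )

        def forward(self, x):
            return self.net(x.view(x.size(0), -1))

    model = VotingClassifier(estimator=MLP, n_estimators=10, cuda=False)

    # Any torch.nn loss module can be passed; the per-class weights here are
    # purely illustrative.
    class_weights = torch.ones(10)
    class_weights[0] = 2.0
    model.set_criterion(nn.CrossEntropyLoss(weight=class_weights))

    model.set_optimizer("Adam", lr=1e-3, weight_decay=5e-4)
    # model.fit(train_loader, epochs=50)  # train_loader: your DataLoader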

torchensemble/_base.py

Lines changed: 9 additions & 6 deletions
@@ -33,6 +33,7 @@ def get_doc(item):
         "predict": const.__predict_doc,
         "set_optimizer": const.__set_optimizer_doc,
         "set_scheduler": const.__set_scheduler_doc,
+        "set_criterion": const.__set_criterion_doc,
         "classifier_forward": const.__classification_forward_doc,
         "classifier_evaluate": const.__classification_evaluate_doc,
         "regressor_forward": const.__regression_forward_doc,
@@ -137,6 +138,10 @@ def _validate_parameters(self, epochs, log_interval):
             self.logger.error(msg.format(log_interval))
             raise ValueError(msg.format(log_interval))
 
+    def set_criterion(self, criterion):
+        """Set the training criterion."""
+        self._criterion = criterion
+
     def set_optimizer(self, optimizer_name, **kwargs):
         """Set the parameter optimizer."""
         self.optimizer_name = optimizer_name
@@ -226,7 +231,6 @@ def evaluate(self, test_loader, return_loss=False):
         self.eval()
         correct = 0
         total = 0
-        criterion = nn.CrossEntropyLoss()
         loss = 0.0
 
         for _, elem in enumerate(test_loader):
@@ -235,7 +239,7 @@ def evaluate(self, test_loader, return_loss=False):
             _, predicted = torch.max(output.data, 1)
             correct += (predicted == target).sum().item()
             total += target.size(0)
-            loss += criterion(output, target)
+            loss += self._criterion(output, target)
 
         acc = 100 * correct / total
         loss /= len(test_loader)
@@ -273,12 +277,11 @@ def _decide_n_outputs(self, train_loader):
     def evaluate(self, test_loader):
         """Docstrings decorated by downstream ensembles."""
         self.eval()
-        mse = 0.0
-        criterion = nn.MSELoss()
+        loss = 0.0
 
         for _, elem in enumerate(test_loader):
             data, target = split_data_target(elem, self.device)
             output = self.forward(*data)
-            mse += criterion(output, target)
+            loss += self._criterion(output, target)
 
-        return float(mse) / len(test_loader)
+        return float(loss) / len(test_loader)
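Because evaluate() (and the per-estimator training loops) now invoke self._criterion(output, target) directly, the stored criterion does not strictly have to be an nn loss module; any callable with the same (output, target) signature should work in principle. A minimal, hypothetical sketch — the label-smoothing function below is illustrative and not part of this commit:

    import torch
    import torch.nn.functional as F

    def label_smoothing_ce(output, target, smoothing=0.1):
        """Illustrative criterion: cross-entropy with label smoothing."""
        n_classes = output.size(1)
        log_probs = F.log_softmax(output, dim=1)
        with torch.no_grad():
            # Smoothed one-hot target distribution.
            true_dist = torch.full_like(log_probs, smoothing / (n_classes - 1))
            true_dist.scatter_(1, target.unsqueeze(1), 1.0 - smoothing)
        return torch.mean(torch.sum(-true_dist * log_probs, dim=1))

    # model.set_criterion(label_smoothing_ce)  # model: any ensemble exposing set_criterion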

torchensemble/_constants.py

Lines changed: 8 additions & 0 deletions
@@ -85,6 +85,14 @@
 """
 
 
+__set_criterion_doc = """
+    Parameters
+    ----------
+    criterion : torch.nn.loss
+        The customized training criterion object.
+"""
+
+
 __fit_doc = """
     Parameters
     ----------
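The documented type "torch.nn.loss" refers to instantiated loss objects from torch.nn. A few common choices that fit this parameter (an illustrative, non-exhaustive list):

    import torch.nn as nn

    # Classification-style criteria (expect logits / log-probabilities and class targets).
    classification_criteria = [nn.CrossEntropyLoss(), nn.NLLLoss()]

    # Regression-style criteria (expect predictions and real-valued targets).
    regression_criteria = [nn.MSELoss(), nn.L1Loss(), nn.SmoothL1Loss()]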

torchensemble/adversarial_training.py

Lines changed: 36 additions & 14 deletions
@@ -246,6 +246,13 @@ def set_optimizer(self, optimizer_name, **kwargs):
     def set_scheduler(self, scheduler_name, **kwargs):
         super().set_scheduler(scheduler_name, **kwargs)
 
+    @torchensemble_model_doc(
+        """Set the training criterion for AdversarialTrainingClassifier.""",
+        "set_criterion",
+    )
+    def set_criterion(self, criterion):
+        super().set_criterion(criterion)
+
     @_adversarial_training_model_doc(
         """Implementation on the training stage of AdversarialTrainingClassifier.""",  # noqa: E501
         "fit",
@@ -282,8 +289,11 @@ def fit(
             optimizers[0], self.scheduler_name, **self.scheduler_args
         )
 
+        # Check the training criterion
+        if not hasattr(self, "_criterion"):
+            self._criterion = nn.CrossEntropyLoss()
+
         # Utils
-        criterion = nn.CrossEntropyLoss()
         best_acc = 0.0
 
         # Internal helper function on pesudo forward
@@ -318,7 +328,7 @@ def _forward(estimators, *x):
                         estimator,
                         cur_lr,
                         optimizer,
-                        criterion,
+                        self._criterion,
                         idx,
                         epoch,
                         log_interval,
@@ -424,6 +434,13 @@ def set_optimizer(self, optimizer_name, **kwargs):
     def set_scheduler(self, scheduler_name, **kwargs):
         super().set_scheduler(scheduler_name, **kwargs)
 
+    @torchensemble_model_doc(
+        """Set the training criterion for AdversarialTrainingRegressor.""",
+        "set_criterion",
+    )
+    def set_criterion(self, criterion):
+        super().set_criterion(criterion)
+
     @_adversarial_training_model_doc(
         """Implementation on the training stage of AdversarialTrainingRegressor.""",  # noqa: E501
         "fit",
@@ -460,9 +477,12 @@ def fit(
             optimizers[0], self.scheduler_name, **self.scheduler_args
         )
 
+        # Check the training criterion
+        if not hasattr(self, "_criterion"):
+            self._criterion = nn.MSELoss()
+
         # Utils
-        criterion = nn.MSELoss()
-        best_mse = float("inf")
+        best_loss = float("inf")
 
         # Internal helper function on pesudo forward
         def _forward(estimators, *x):
@@ -494,7 +514,7 @@ def _forward(estimators, *x):
                         estimator,
                         cur_lr,
                         optimizer,
-                        criterion,
+                        self._criterion,
                         idx,
                         epoch,
                         log_interval,
@@ -515,31 +535,33 @@ def _forward(estimators, *x):
             if test_loader:
                 self.eval()
                 with torch.no_grad():
-                    mse = 0.0
+                    val_loss = 0.0
                     for _, elem in enumerate(test_loader):
                         data, target = io.split_data_target(
                             elem, self.device
                         )
                         output = _forward(estimators, *data)
-                        mse += criterion(output, target)
-                    mse /= len(test_loader)
+                        val_loss += self._criterion(output, target)
+                    val_loss /= len(test_loader)
 
-                    if mse < best_mse:
-                        best_mse = mse
+                    if val_loss < best_loss:
+                        best_loss = val_loss
                         self.estimators_ = nn.ModuleList()
                         self.estimators_.extend(estimators)
                         if save_model:
                             io.save(self, save_dir, self.logger)
 
                     msg = (
-                        "Epoch: {:03d} | Validation MSE:"
+                        "Epoch: {:03d} | Validation Loss:"
                         " {:.5f} | Historical Best: {:.5f}"
                     )
-                    self.logger.info(msg.format(epoch, mse, best_mse))
+                    self.logger.info(
+                        msg.format(epoch, val_loss, best_loss)
+                    )
                     if self.tb_logger:
                         self.tb_logger.add_scalar(
-                            "adversirial_training/Validation_MSE",
-                            mse,
+                            "adversirial_training/Validation_Loss",
+                            val_loss,
                             epoch,
                         )

torchensemble/bagging.py

Lines changed: 35 additions & 13 deletions
@@ -123,6 +123,13 @@ def set_optimizer(self, optimizer_name, **kwargs):
     def set_scheduler(self, scheduler_name, **kwargs):
         super().set_scheduler(scheduler_name, **kwargs)
 
+    @torchensemble_model_doc(
+        """Set the training criterion for BaggingClassifier.""",
+        "set_criterion",
+    )
+    def set_criterion(self, criterion):
+        super().set_criterion(criterion)
+
     @torchensemble_model_doc(
         """Implementation on the training stage of BaggingClassifier.""", "fit"
     )
@@ -157,8 +164,11 @@ def fit(
             optimizers[0], self.scheduler_name, **self.scheduler_args
         )
 
+        # Check the training criterion
+        if not hasattr(self, "_criterion"):
+            self._criterion = nn.CrossEntropyLoss()
+
         # Utils
-        criterion = nn.CrossEntropyLoss()
         best_acc = 0.0
 
         # Internal helper function on pesudo forward
@@ -192,7 +202,7 @@ def _forward(estimators, *x):
                         estimator,
                         cur_lr,
                         optimizer,
-                        criterion,
+                        self._criterion,
                         idx,
                         epoch,
                         log_interval,
@@ -295,6 +305,13 @@ def set_optimizer(self, optimizer_name, **kwargs):
     def set_scheduler(self, scheduler_name, **kwargs):
         super().set_scheduler(scheduler_name, **kwargs)
 
+    @torchensemble_model_doc(
+        """Set the training criterion for BaggingRegressor.""",
+        "set_criterion",
+    )
+    def set_criterion(self, criterion):
+        super().set_criterion(criterion)
+
     @torchensemble_model_doc(
         """Implementation on the training stage of BaggingRegressor.""", "fit"
     )
@@ -329,9 +346,12 @@ def fit(
             optimizers[0], self.scheduler_name, **self.scheduler_args
         )
 
+        # Check the training criterion
+        if not hasattr(self, "_criterion"):
+            self._criterion = nn.MSELoss()
+
         # Utils
-        criterion = nn.MSELoss()
-        best_mse = float("inf")
+        best_loss = float("inf")
 
         # Internal helper function on pesudo forward
         def _forward(estimators, *x):
@@ -362,7 +382,7 @@ def _forward(estimators, *x):
                         estimator,
                         cur_lr,
                         optimizer,
-                        criterion,
+                        self._criterion,
                         idx,
                         epoch,
                         log_interval,
@@ -383,30 +403,32 @@ def _forward(estimators, *x):
             if test_loader:
                 self.eval()
                 with torch.no_grad():
-                    mse = 0.0
+                    val_loss = 0.0
                     for _, elem in enumerate(test_loader):
                         data, target = io.split_data_target(
                             elem, self.device
                         )
                         output = _forward(estimators, *data)
-                        mse += criterion(output, target)
-                    mse /= len(test_loader)
+                        val_loss += self._criterion(output, target)
+                    val_loss /= len(test_loader)
 
-                    if mse < best_mse:
-                        best_mse = mse
+                    if val_loss < best_loss:
+                        best_loss = val_loss
                         self.estimators_ = nn.ModuleList()
                         self.estimators_.extend(estimators)
                         if save_model:
                             io.save(self, save_dir, self.logger)
 
                     msg = (
-                        "Epoch: {:03d} | Validation MSE:"
+                        "Epoch: {:03d} | Validation Loss:"
                        " {:.5f} | Historical Best: {:.5f}"
                     )
-                    self.logger.info(msg.format(epoch, mse, best_mse))
+                    self.logger.info(
+                        msg.format(epoch, val_loss, best_loss)
+                    )
                     if self.tb_logger:
                         self.tb_logger.add_scalar(
-                            "bagging/Validation_MSE", mse, epoch
+                            "bagging/Validation_Loss", val_loss, epoch
                         )
 
             # Update the scheduler
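One consequence visible in the hunks above is that the regressor's validation metric during fit() now reuses whatever criterion was set, logged as "Validation Loss" instead of a hard-coded MSE. A hedged sketch using BaggingRegressor with an L1 loss; the Net module and loader names are assumptions for illustration:

    import torch.nn as nn
    from torchensemble import BaggingRegressor

    # Hypothetical base network for illustration.
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(8, 1)

        def forward(self, x):
            return self.linear(x)

    model = BaggingRegressor(estimator=Net, n_estimators=5, cuda=False)
    model.set_optimizer("Adam", lr=1e-3)

    # Mean absolute error is now used both for training each estimator and for
    # the "Validation Loss" tracked on test_loader.
    model.set_criterion(nn.L1Loss())
    # model.fit(train_loader, epochs=10, test_loader=val_loader)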
