Skip to content

Commit 9c94e76

Browse files
committed
addressing PR reviews + fixed variable fetching in EWCRegularizer
1 parent 5fb7fc9 commit 9c94e76

File tree

3 files changed

+67
-58
lines changed

3 files changed

+67
-58
lines changed

neuralmonkey/trainers/cross_entropy_trainer.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import tensorflow as tf
44
from typeguard import check_argument_types
55

6-
from neuralmonkey.logging import warn
76
from neuralmonkey.trainers.generic_trainer import (
87
GenericTrainer, Objective, ObjectiveWeight)
98
from neuralmonkey.trainers.regularizers import (
@@ -42,17 +41,13 @@ def __init__(self,
4241

4342
if regularizers is None:
4443
regularizers = []
45-
if l1_weight > 0.:
46-
if L1Regularizer in [type(r) for r in regularizers]:
47-
warn("You specified both trainer l1_weight "
48-
"and a L1Regularizer object in your config")
49-
regularizers.append(L1Regularizer(weight=l1_weight))
5044

45+
if l1_weight > 0.:
46+
regularizers.append(
47+
L1Regularizer(name="train_l1", weight=l1_weight))
5148
if l2_weight > 0.:
52-
if L2Regularizer in [type(r) for r in regularizers]:
53-
warn("You specified both trainer l2_weight "
54-
"and a L2Regularizer object in your config")
55-
regularizers.append(L2Regularizer(weight=l2_weight))
49+
regularizers.append(
50+
L2Regularizer(name="train_l2", weight=l2_weight))
5651

5752
if len(decoder_weights) != len(decoders):
5853
raise ValueError(

neuralmonkey/trainers/generic_trainer.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
from neuralmonkey.model.model_part import ModelPart
88
from neuralmonkey.runners.base_runner import (
99
Executable, ExecutionResult, NextExecute)
10-
from neuralmonkey.trainers.regularizers import (
11-
Regularizer, L2Regularizer)
10+
from neuralmonkey.trainers.regularizers import (Regularizer, L2Regularizer)
1211

1312
# pylint: disable=invalid-name
1413
Gradients = List[Tuple[tf.Tensor, tf.Variable]]
@@ -40,6 +39,7 @@ class Objective(NamedTuple(
4039

4140

4241
# pylint: disable=too-few-public-methods,too-many-locals,too-many-branches
42+
# pylint: disable=too-many-statements
4343
class GenericTrainer:
4444

4545
def __init__(self,
@@ -102,7 +102,9 @@ def __init__(self,
102102

103103
# we always want to include l2 values in the summary
104104
if L2Regularizer not in [type(r) for r in self.regularizers]:
105-
reg_values.append(L2Regularizer().value(regularizable))
105+
l2_reg = L2Regularizer(name="train_l2", weight=0.)
106+
tf.summary.scalar(l2_reg.name, l2_reg.value(regularizable),
107+
collections=["summary_train"])
106108
for reg, reg_value in zip(self.regularizers, reg_values):
107109
tf.summary.scalar(reg.name, reg_value,
108110
collections=["summary_train"])
@@ -119,8 +121,8 @@ def __init__(self,
119121
with tf.name_scope("gradient_collection"):
120122
differentiable_loss_sum = sum(
121123
[(o.weight if o.weight is not None else 1.) * o.loss
122-
for o in objectives
123-
if o.gradients is None] + reg_costs)
124+
for o in objectives if o.gradients is None])
125+
differentiable_loss_sum += sum(reg_costs)
124126
implicit_gradients = self._get_gradients(
125127
differentiable_loss_sum)
126128

@@ -130,25 +132,24 @@ def __init__(self,
130132
for o in objectives if o.gradients is not None]
131133

132134
if other_gradients:
133-
gradients = _sum_gradients(
135+
self.gradients = _sum_gradients(
134136
[implicit_gradients] + other_gradients)
135137
else:
136-
gradients = implicit_gradients
138+
self.gradients = implicit_gradients
137139

138140
tf.summary.scalar("train_opt_cost",
139141
differentiable_loss_sum,
140142
collections=["summary_train"])
141143

142144
if clip_norm:
143145
assert clip_norm > 0.0
144-
gradients = [(tf.clip_by_norm(grad, clip_norm), var)
145-
for grad, var in gradients
146-
if grad is not None]
146+
self.gradients = [(tf.clip_by_norm(grad, clip_norm), var)
147+
for grad, var in self.gradients
148+
if grad is not None]
147149

148150
self.all_coders = set.union(*(obj.decoder.get_dependencies()
149151
for obj in objectives))
150152

151-
self.gradients = gradients
152153
self.train_op = self.optimizer.apply_gradients(
153154
self.gradients, global_step=step)
154155

neuralmonkey/trainers/regularizers.py

Lines changed: 50 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
This module contains classes that can be used as a variable regularizers
44
during training. All implementation should be derived from the Regularizer
55
class.
6-
76
"""
7+
from abc import ABCMeta, abstractmethod
88
from typing import List
99

1010
import numpy as np
@@ -14,8 +14,14 @@
1414
from neuralmonkey.logging import log
1515

1616

17-
class Regularizer:
18-
"""Base class for the regularizers."""
17+
class Regularizer(metaclass=ABCMeta):
18+
"""Base clas s for regularizers.
19+
20+
Regularizer objects are used to introduce additional loss terms to
21+
the trainerthus constraining the model variable during training. These
22+
loss terms have an adjustable weight allowing to set the ``importance''
23+
of the term.
24+
"""
1925

2026
def __init__(self,
2127
name: str,
@@ -24,10 +30,9 @@ def __init__(self,
2430
2531
Arguments:
2632
name: Regularizer name.
27-
weight: Weight of the regularization term.
33+
weight: Weight of the regularization term (usually expressed
34+
as ``lambda'' in the literature).
2835
"""
29-
check_argument_types()
30-
3136
self._name = name
3237
self._weight = weight
3338

@@ -39,34 +44,40 @@ def name(self) -> str:
3944
def weight(self) -> float:
4045
return self._weight
4146

42-
def value(self, variables) -> float:
47+
@abstractmethod
48+
def value(self, variables: List[tf.Tensor]) -> tf.Tensor:
49+
"""Compute the unweighted value of the regularization loss term.
50+
51+
Arguments:
52+
variables: List of the regularizable model variables.
53+
"""
4354
raise NotImplementedError("Abstract method")
4455

4556

4657
class L1Regularizer(Regularizer):
4758
"""L1 regularizer."""
4859

4960
def __init__(self,
50-
name: str = "train_l1",
51-
weight: float = 1.0e-8) -> None:
61+
name: str,
62+
weight: float) -> None:
5263
"""Create the regularizer.
5364
5465
Arguments:
5566
name: Regularizer name.
56-
weight: Weight of the regularization term.
67+
weight: Weight of the regularization term (default=1.0e-8.
5768
"""
5869
Regularizer.__init__(self, name, weight)
5970

60-
def value(self, variables: List[tf.Tensor]) -> float:
71+
def value(self, variables: List[tf.Tensor]) -> tf.Tensor:
6172
return sum(tf.reduce_sum(abs(v)) for v in variables)
6273

6374

6475
class L2Regularizer(Regularizer):
6576
"""L2 regularizer."""
6677

6778
def __init__(self,
68-
name: str = "train_l2",
69-
weight: float = 1.0e-8) -> None:
79+
name: str,
80+
weight: float) -> None:
7081
"""Create the regularizer.
7182
7283
Arguments:
@@ -75,7 +86,7 @@ def __init__(self,
7586
"""
7687
Regularizer.__init__(self, name, weight)
7788

78-
def value(self, variables: List[tf.Tensor]) -> float:
89+
def value(self, variables: List[tf.Tensor]) -> tf.Tensor:
7990
return sum(tf.reduce_sum(v ** 2) for v in variables)
8091

8192

@@ -84,15 +95,18 @@ class EWCRegularizer(Regularizer):
8495
8596
Implements Elastic Weight Consolidation from the "Overcoming catastrophic
8697
forgetting in neural networks" paper.
98+
The regularizer applies separate regularization weight to each trainable
99+
variable based on how important the variable was for the previously
100+
learned task.
87101
88102
https://arxiv.org/pdf/1612.00796.pdf
89103
"""
90104

91105
def __init__(self,
92-
name: str = "train_ewc",
93-
weight: float = 0.,
94-
gradients_file: str = None,
95-
variables_file: str = None) -> None:
106+
name: str,
107+
weight: float,
108+
gradients_file: str,
109+
variables_file: str) -> None:
96110
"""Create the regularizer.
97111
98112
Arguments:
@@ -104,36 +118,35 @@ def __init__(self,
104118
on the previous task.
105119
"""
106120
check_argument_types()
107-
108121
Regularizer.__init__(self, name, weight)
109122

110-
if gradients_file is None:
111-
raise ValueError("Missing gradients_file")
112-
if variables_file is None:
113-
raise ValueError("Missing variables_file")
114-
115-
log("Loading initial variables for EWC from {}".format(variables_file))
123+
log("Loading initial variables for EWC from "
124+
"{}.".format(variables_file))
116125
self.init_vars = tf.contrib.framework.load_checkpoint(variables_file)
117-
log("EWC initial variables loaded")
126+
log("EWC initial variables loaded.")
118127

119-
log("Loading gradient estimates from {}".format(gradients_file))
128+
log("Loading gradient estimates from {}.".format(gradients_file))
120129
self.gradients = np.load(gradients_file)
121-
log("Gradient estimates loaded")
130+
log("Gradient estimates loaded.")
122131

123-
def value(self, variables: List[tf.Tensor]) -> float:
132+
def value(self, variables: List[tf.Tensor]) -> tf.Tensor:
124133
ewc_value = tf.constant(0.0)
125134
for var in variables:
126-
var_name = var.name.split(":")[0]
135+
var_name = var.name
136+
init_var_name = var_name.split(":")[0]
127137
if (var_name in self.gradients.files
128-
and self.init_vars.has_tensor(var_name)):
129-
init_var = self.init_vars.get_tensor(var_name)
130-
gradient = tf.constant(
131-
self.gradients[var_name], name="ewc_gradients")
138+
and self.init_vars.has_tensor(init_var_name)):
139+
init_var = tf.constant(
140+
self.init_vars.get_tensor(init_var_name),
141+
name="{}_init_value".format(init_var_name))
142+
grad_squared = tf.constant(
143+
np.square(self.gradients[var_name]),
144+
name="{}_ewc_weight".format(init_var_name))
132145
ewc_value += tf.reduce_sum(tf.multiply(
133-
tf.square(gradient), tf.square(var - init_var)))
146+
grad_squared, tf.square(var - init_var)))
134147

135148
return ewc_value
136149

137150

138-
L1 = L1Regularizer()
139-
L2 = L2Regularizer()
151+
L1 = L1Regularizer(name="train_l1", weight=1.0e-8)
152+
L2 = L2Regularizer(name="train_l2", weight=1.0e-8)

0 commit comments

Comments
 (0)