Skip to content

Commit a2fa6bc

Browse files
committed
removed squaring of gradients in EWCRegularizer
1 parent 99838e3 commit a2fa6bc

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

neuralmonkey/trainers/regularizers.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,15 @@ class EWCRegularizer(Regularizer):
104104
def __init__(self,
105105
name: str,
106106
weight: float,
107-
gradients_file: str,
107+
fisher_file: str,
108108
variables_file: str) -> None:
109109
"""Create the regularizer.
110110
111111
Arguments:
112112
name: Regularizer name.
113113
weight: Weight of the regularization term.
114-
gradients_file: File containing the gradient estimates
115-
from the previous task.
114+
fisher_file: File containing the diagonal of the Fisher information
115+
matrix estimated on the previous task.
116116
variables_file: File containing the variables learned
117117
on the previous task.
118118
"""
@@ -124,23 +124,28 @@ def __init__(self,
124124
self.init_vars = tf.contrib.framework.load_checkpoint(variables_file)
125125
log("EWC initial variables loaded.")
126126

127-
log("Loading gradient estimates from {}.".format(gradients_file))
128-
self.gradients = np.load(gradients_file)
127+
log("Loading gradient estimates from {}.".format(fisher_file))
128+
self.fisher = np.load(fisher_file)
129129
log("Gradient estimates loaded.")
130130

131131
def value(self, variables: List[tf.Tensor]) -> tf.Tensor:
132+
r"""Compute the value of the regularization term.
133+
134+
value = \sum_{i} (λ * F_{i} * (θ_{i} - θ_{i}^{*})^2)
135+
136+
where λ is the regularizer weight and F is the diagonal
137+
of the Fisher Information matrix, and θ^{*} are the variable values
learned on the previous task.
138+
"""
139+
132140
ewc_value = tf.constant(0.0)
133141
for var in variables:
134142
init_var_name = var.name.split(":")[0]
135-
if (var.name in self.gradients.files
143+
if (var.name in self.fisher.files
136144
and self.init_vars.has_tensor(init_var_name)):
137145
init_var = tf.constant(
138146
self.init_vars.get_tensor(init_var_name),
139147
name="{}_init_value".format(init_var_name))
140-
grad_squared = tf.constant(
141-
np.square(self.gradients[var.name]),
142-
name="{}_ewc_weight".format(init_var_name))
143148
ewc_value += tf.reduce_sum(tf.multiply(
144-
grad_squared, tf.square(var - init_var)))
149+
self.fisher[var.name], tf.square(var - init_var)))
145150

146151
return ewc_value

0 commit comments

Comments
 (0)