@@ -104,15 +104,15 @@ class EWCRegularizer(Regularizer):
104104 def __init__ (self ,
105105 name : str ,
106106 weight : float ,
107- gradients_file : str ,
107+ fisher_file : str ,
108108 variables_file : str ) -> None :
109109 """Create the regularizer.
110110
111111 Arguments:
112112 name: Regularizer name.
113113 weight: Weight of the regularization term.
114- gradients_file : File containing the gradient estimates
115- from the previous task.
114+ fisher_file : File containing the diagonal of the fisher information
115+ matrix estimated on the previous task.
116116 variables_files: File containing the variables learned
117117 on the previous task.
118118 """
@@ -124,23 +124,28 @@ def __init__(self,
124124 self .init_vars = tf .contrib .framework .load_checkpoint (variables_file )
125125 log ("EWC initial variables loaded." )
126126
127- log ("Loading gradient estimates from {}." .format (gradients_file ))
128- self .gradients = np .load (gradients_file )
127+ log ("Loading gradient estimates from {}." .format (fisher_file ))
128+ self .fisher = np .load (fisher_file )
129129 log ("Gradient estimates loaded." )
130130
131131 def value (self , variables : List [tf .Tensor ]) -> tf .Tensor :
132+ r"""Compute the value of the regularization term.
133+
134+ value = \sum_{i} (λ * F_{i} * (θ_{i} - θ_{i}^{*})^2)
135+
136+ where λ is the regularizer weight and F is the diagonal
137+ of the Fisher Information matrix.
138+ """
139+
132140 ewc_value = tf .constant (0.0 )
133141 for var in variables :
134142 init_var_name = var .name .split (":" )[0 ]
135- if (var .name in self .gradients .files
143+ if (var .name in self .fisher .files
136144 and self .init_vars .has_tensor (init_var_name )):
137145 init_var = tf .constant (
138146 self .init_vars .get_tensor (init_var_name ),
139147 name = "{}_init_value" .format (init_var_name ))
140- grad_squared = tf .constant (
141- np .square (self .gradients [var .name ]),
142- name = "{}_ewc_weight" .format (init_var_name ))
143148 ewc_value += tf .reduce_sum (tf .multiply (
144- grad_squared , tf .square (var - init_var )))
149+ self . fisher [ var . name ] , tf .square (var - init_var )))
145150
146151 return ewc_value
0 commit comments