This module contains classes that can be used as variable regularizers
during training. All implementations should be derived from the Regularizer
class.
"""
7+ from abc import ABCMeta , abstractmethod
88from typing import List
99
1010import numpy as np
1414from neuralmonkey .logging import log
1515
1616
class Regularizer(metaclass=ABCMeta):
    """Base class for regularizers.

    Regularizer objects are used to introduce additional loss terms to
    the trainer, thus constraining the model variables during training.
    These loss terms have an adjustable weight allowing to set the
    "importance" of the term.
    """

    def __init__(self,
                 name: str,
                 weight: float) -> None:
        """Create the regularizer.

        Arguments:
            name: Regularizer name.
            weight: Weight of the regularization term (usually expressed
                as "lambda" in the literature).
        """
        self._name = name
        self._weight = weight

    @property
    def name(self) -> str:
        """Return the name of the regularizer."""
        return self._name

    @property
    def weight(self) -> float:
        """Return the weight of the regularization term."""
        return self._weight

    @abstractmethod
    def value(self, variables: List[tf.Tensor]) -> tf.Tensor:
        """Compute the unweighted value of the regularization loss term.

        Arguments:
            variables: List of the regularizable model variables.
        """
        raise NotImplementedError("Abstract method")
4455
4556
class L1Regularizer(Regularizer):
    """L1 regularizer.

    Penalizes the sum of absolute values of the model variables,
    encouraging sparse solutions.
    """

    def __init__(self,
                 name: str,
                 weight: float) -> None:
        """Create the regularizer.

        Arguments:
            name: Regularizer name.
            weight: Weight of the regularization term.
        """
        Regularizer.__init__(self, name, weight)

    def value(self, variables: List[tf.Tensor]) -> tf.Tensor:
        """Return the sum of L1 norms of the given variables."""
        return sum(tf.reduce_sum(abs(v)) for v in variables)
6273
6374
class L2Regularizer(Regularizer):
    """L2 regularizer.

    Penalizes the sum of squares of the model variables, encouraging
    variable values close to zero.
    """

    def __init__(self,
                 name: str,
                 weight: float) -> None:
        """Create the regularizer.

        Arguments:
            name: Regularizer name.
            weight: Weight of the regularization term.
        """
        Regularizer.__init__(self, name, weight)

    def value(self, variables: List[tf.Tensor]) -> tf.Tensor:
        """Return the sum of squared values of the given variables."""
        return sum(tf.reduce_sum(v ** 2) for v in variables)
8091
8192
class EWCRegularizer(Regularizer):
    """Regularizer based on Elastic Weight Consolidation.

    Implements Elastic Weight Consolidation from the "Overcoming catastrophic
    forgetting in neural networks" paper.
    The regularizer applies a separate regularization weight to each trainable
    variable based on how important the variable was for the previously
    learned task.

    https://arxiv.org/pdf/1612.00796.pdf
    """

    def __init__(self,
                 name: str,
                 weight: float,
                 gradients_file: str,
                 variables_file: str) -> None:
        """Create the regularizer.

        Arguments:
            name: Regularizer name.
            weight: Weight of the regularization term.
            gradients_file: File with the gradient estimates from the
                previous task (an ``np.savez`` archive keyed by variable
                name).
            variables_file: TensorFlow checkpoint with the variable values
                learned on the previous task.
        """
        check_argument_types()
        Regularizer.__init__(self, name, weight)

        log("Loading initial variables for EWC from "
            "{}.".format(variables_file))
        self.init_vars = tf.contrib.framework.load_checkpoint(variables_file)
        log("EWC initial variables loaded.")

        log("Loading gradient estimates from {}.".format(gradients_file))
        self.gradients = np.load(gradients_file)
        log("Gradient estimates loaded.")

    def value(self, variables: List[tf.Tensor]) -> tf.Tensor:
        """Compute the EWC regularization loss term.

        For every variable that has both a stored gradient estimate and a
        value in the initial checkpoint, the squared distance from the
        initial value is added to the loss, weighted element-wise by the
        squared gradient estimate.

        Arguments:
            variables: List of the regularizable model variables.
        """
        ewc_value = tf.constant(0.0)
        for var in variables:
            var_name = var.name
            # The npz archive is keyed by the full variable name (including
            # the ":<output>" suffix), while checkpoint tensor names carry
            # no such suffix — strip it for the checkpoint lookup.
            init_var_name = var_name.split(":")[0]
            if (var_name in self.gradients.files
                    and self.init_vars.has_tensor(init_var_name)):
                init_var = tf.constant(
                    self.init_vars.get_tensor(init_var_name),
                    name="{}_init_value".format(init_var_name))
                grad_squared = tf.constant(
                    np.square(self.gradients[var_name]),
                    name="{}_ewc_weight".format(init_var_name))
                ewc_value += tf.reduce_sum(tf.multiply(
                    grad_squared, tf.square(var - init_var)))

        return ewc_value
136149
137150
# Ready-made regularizer instances with commonly used default weights.
L1 = L1Regularizer("train_l1", 1.0e-8)
L2 = L2Regularizer("train_l2", 1.0e-8)
0 commit comments