@@ -29,7 +29,7 @@ def TransformerEncoder(mode='train', # pylint: disable=invalid-name
                        feature_depth=512,
                        feedforward_depth=2048,
                        num_heads=8,
-                       dropout=0.9):
+                       dropout=0.1):
   """Transformer Encoder Stack.
 
   Args:
@@ -38,20 +38,22 @@ def TransformerEncoder(mode='train', # pylint: disable=invalid-name
     feature_depth: int: depth of embedding
     feedforward_depth: int: depth of feed-forward layer
     num_heads: int: number of attention heads
-    dropout: float: dropout rate - Stax follows TF's KEEP probability convention
+    dropout: float: dropout rate (how much to drop out; note that stax follows
+      Tensorflow's keep_rate convention, so we use 1 - dropout in calls below)
 
   Returns:
     A staxlayer for implementing a raw Transformer encoder stack. No embedding
     or positional signals are added by this layer.
   """
+  keep_rate = 1.0 - dropout
   # Multi-headed Attention and Feed-forward layers
   multi_attention = stax.MultiHeadedAttention(
-      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)
+      feature_depth, num_heads=num_heads, dropout=keep_rate, mode=mode)
 
   feed_forward = stax.serial(
       stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
       stax.Relu,
-      stax.Dropout(dropout, mode=mode),
+      stax.Dropout(keep_rate, mode=mode),
       stax.Dense(feature_depth, W_init=stax.xavier_uniform())
   )
 
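The hunk above changes the meaning of the public `dropout` argument from a keep probability to a drop rate and converts it once at the top of the function. A minimal sketch of the arithmetic, using only the default values from this commit:

dropout = 0.1                        # new-style argument: fraction of units to drop
keep_rate = 1.0 - dropout            # value actually handed to the stax layers
assert abs(keep_rate - 0.9) < 1e-9   # equals the old default, which was passed through directly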
@@ -74,11 +76,11 @@ def encoder(embedded_source, source_mask):
                                      stax.Identity,  # value
                                      source_mask),  # attention mask
                       multi_attention,
-                      stax.Dropout(dropout, mode=mode)),
+                      stax.Dropout(keep_rate, mode=mode)),
         # feed-forward
         stax.residual(stax.LayerNorm(feature_depth),
                       feed_forward,
-                      stax.Dropout(dropout, mode=mode))
+                      stax.Dropout(keep_rate, mode=mode))
     )
     return stax.serial(
         embedded_source,
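Both sub-blocks above follow the same pre-norm residual pattern: normalize, apply the sublayer, drop out, add the input back. The sketch below restates that data flow in plain NumPy under the assumption that `stax.residual` composes its arguments serially and adds a skip connection; `layer_norm`, the toy sublayer, and the inverted-dropout mask are stand-ins, not the stax implementations.

import numpy as np

def layer_norm(x, eps=1e-6):
    # stand-in for stax.LayerNorm: normalize over the feature axis
    mean = x.mean(axis=-1, keepdims=True)
    std = x.std(axis=-1, keepdims=True)
    return (x - mean) / (std + eps)

def residual_block(x, sublayer, keep_rate, rng):
    # assumed wiring of stax.residual(LayerNorm, sublayer, Dropout):
    # y = x + Dropout(sublayer(LayerNorm(x)))
    h = sublayer(layer_norm(x))
    mask = rng.binomial(1, keep_rate, size=h.shape) / keep_rate  # inverted dropout
    return x + h * mask

x = np.random.randn(2, 5, 8)              # (batch, length, feature_depth)
y = residual_block(x, lambda h: 2.0 * h,  # toy sublayer in place of attention
                   keep_rate=0.9, rng=np.random.default_rng(0))
print(y.shape)  # (2, 5, 8)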
@@ -95,8 +97,8 @@ def TransformerLM(vocab_size, # pylint: disable=invalid-name
                   feature_depth=512,
                   feedforward_depth=2048,
                   num_heads=8,
-                  dropout=0.9,
-                  max_len=256):
+                  dropout=0.1,
+                  max_len=512):
   """Transformer language model (only uses the decoder part of Transformer).
 
   Args:
@@ -106,20 +108,21 @@ def TransformerLM(vocab_size, # pylint: disable=invalid-name
     feature_depth: int: depth of embedding
     feedforward_depth: int: depth of feed-forward layer
     num_heads: int: number of attention heads
-    dropout: float: dropout rate - Stax follows TF's KEEP probability convention
+    dropout: float: dropout rate (how much to drop out)
     max_len: int: maximum symbol length for positional encoding
 
   Returns:
     init and apply.
   """
+  keep_rate = 1.0 - dropout
   # Multi-headed Attention and Feed-forward layers
   multi_attention = stax.MultiHeadedAttention(
-      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)
+      feature_depth, num_heads=num_heads, dropout=keep_rate, mode=mode)
 
   feed_forward = stax.serial(
       stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
       stax.Relu,
-      stax.Dropout(dropout, mode=mode),
+      stax.Dropout(keep_rate, mode=mode),
       stax.Dense(feature_depth, W_init=stax.xavier_uniform())
   )
 
@@ -132,18 +135,18 @@ def TransformerLM(vocab_size, # pylint: disable=invalid-name
                                    stax.Identity,  # value
                                    stax.CausalMask(axis=-2)),  # attention mask
                     multi_attention,
-                    stax.Dropout(dropout, mode=mode)),
+                    stax.Dropout(keep_rate, mode=mode)),
       # feed-forward
       stax.residual(stax.LayerNorm(feature_depth),
                     feed_forward,
-                    stax.Dropout(dropout, mode=mode))
+                    stax.Dropout(keep_rate, mode=mode))
   )
 
   return stax.serial(
       stax.ShiftRight(),
       stax.Embedding(feature_depth, vocab_size),
       stax.PositionalEncoding(feature_depth, max_len=max_len),
-      stax.Dropout(dropout, mode=mode),
+      stax.Dropout(keep_rate, mode=mode),
       stax.repeat(decoder_layer, num_layers),
       stax.LayerNorm(feature_depth),
       stax.Dense(vocab_size, W_init=stax.xavier_uniform()),
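For orientation, a hedged usage sketch of `TransformerLM` with the new defaults from this commit. It assumes the usual stax `(init_fun, apply_fun)` convention of `init_fun(rng, input_shape)` and `apply_fun(params, inputs, rng=...)`; the exact signatures and module path in this fork may differ, and the shapes below are hypothetical.

import jax.random as random

# TransformerLM is the function defined above; its import path is omitted here.
init_fun, apply_fun = TransformerLM(vocab_size=32000,  # hypothetical vocabulary size
                                    mode='train',
                                    dropout=0.1,       # drop 10% of activations
                                    max_len=512)       # new positional-encoding limit

rng = random.PRNGKey(0)
input_shape = (8, 128)                             # (batch, length), hypothetical
output_shape, params = init_fun(rng, input_shape)  # assumed init signature
# logits = apply_fun(params, token_ids, rng=rng)   # assumed apply signature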
@@ -158,7 +161,7 @@ def Transformer(source_vocab_size, # pylint: disable=invalid-name
                 feature_depth=512,
                 feedforward_depth=2048,
                 num_heads=8,
-                dropout=0.9,
+                dropout=0.1,
                 shared_embedding=True,
                 max_len=200,
                 return_evals=False):
@@ -172,7 +175,7 @@ def Transformer(source_vocab_size, # pylint: disable=invalid-name
     feature_depth: int: depth of embedding
     feedforward_depth: int: depth of feed-forward layer
     num_heads: int: number of attention heads
-    dropout: float: dropout rate - Stax follows TF's KEEP probability convention
+    dropout: float: dropout rate (how much to drop out)
     shared_embedding: bool: specify whether source/target embeddings are tied.
     max_len: int: maximum symbol length for positional encoding
     return_evals: bool: whether to generate decode-time evaluation functions
@@ -182,11 +185,11 @@ def Transformer(source_vocab_size, # pylint: disable=invalid-name
     the 'evals' functions that itself returns a namedtuple containing evaluation
     functions for the trained encoder, decoder, and generator substax.
   """
-
+  keep_rate = 1.0 - dropout
   # Input embedding and positional encoding
   inject_position = stax.serial(
       stax.PositionalEncoding(feature_depth, max_len=max_len),
-      stax.Dropout(dropout, mode=mode)
+      stax.Dropout(keep_rate, mode=mode)
   )
   if shared_embedding:
     assert source_vocab_size == target_vocab_size
@@ -202,12 +205,12 @@ def Transformer(source_vocab_size, # pylint: disable=invalid-name
 
   # Multi-headed Attention and Feed-forward layers
   multi_attention = stax.MultiHeadedAttention(
-      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)
+      feature_depth, num_heads=num_heads, dropout=keep_rate, mode=mode)
 
   feed_forward = stax.serial(
       stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
       stax.Relu,
-      stax.Dropout(dropout, mode=mode),
+      stax.Dropout(keep_rate, mode=mode),
       stax.Dense(feature_depth, W_init=stax.xavier_uniform())
   )
 
@@ -231,11 +234,11 @@ def encoder(source, source_mask):
                                      stax.Identity,  # value
                                      source_mask),  # attention mask
                       multi_attention,
-                      stax.Dropout(dropout, mode=mode)),
+                      stax.Dropout(keep_rate, mode=mode)),
         # feed-forward
         stax.residual(stax.LayerNorm(feature_depth),
                       feed_forward,
-                      stax.Dropout(dropout, mode=mode))
+                      stax.Dropout(keep_rate, mode=mode))
     )
     return stax.serial(
         source,
@@ -266,19 +269,19 @@ def decoder(memory, target, target_mask, memory_mask):
                                      stax.Identity,  # value
                                      target_mask),  # attention mask
                       multi_attention,
-                      stax.Dropout(dropout, mode=mode)),
+                      stax.Dropout(keep_rate, mode=mode)),
         # target attends to encoded source
         stax.residual(stax.LayerNorm(feature_depth),
                       stax.multiplex(stax.Identity,  # query
                                      memory,  # key
                                      memory,  # value
                                      memory_mask),  # attention mask
                       multi_attention,
-                      stax.Dropout(dropout, mode=mode)),
+                      stax.Dropout(keep_rate, mode=mode)),
         # feed-forward
         stax.residual(stax.LayerNorm(feature_depth),
                       feed_forward,
-                      stax.Dropout(dropout, mode=mode))
+                      stax.Dropout(keep_rate, mode=mode))
     )
     return stax.serial(
         target,
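In the second residual sub-block above, the decoder stream supplies the query while the encoder output (`memory`) supplies both key and value. The sketch below restates that wiring as single-head scaled dot-product attention in plain NumPy; it is illustrative only (no projections, no multiple heads, no dropout) and is not the `stax.MultiHeadedAttention` implementation.

import numpy as np

def scaled_dot_attention(query, key, value, mask):
    # query: (L_target, d), key/value: (L_source, d), mask: (L_target, L_source) bool
    d = query.shape[-1]
    scores = query @ key.T / np.sqrt(d)
    scores = np.where(mask, scores, -1e9)  # mask out disallowed positions
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ value

target = np.random.randn(4, 8)             # decoder stream (query)
memory = np.random.randn(6, 8)             # encoder output (key and value)
memory_mask = np.ones((4, 6), dtype=bool)  # attend to every source position
out = scaled_dot_attention(target, memory, memory, memory_mask)
print(out.shape)  # (4, 8)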