@@ -166,17 +166,17 @@ def add_positional_embedding_nd(x, max_length, name):
 
 
 def embedding_to_padding(emb):
-  """Input embeddings -> is_padding.
+  """Calculates the padding mask based on which embeddings are all zero.
 
   We have hacked symbol_modality to return all-zero embeddings for padding.
 
   Args:
     emb: a Tensor with shape [..., depth].
   Returns:
-    a boolean Tensor with shape [...].
+    a float Tensor with shape [...].
   """
   emb_sum = tf.reduce_sum(tf.abs(emb), axis=-1)
-  return tf.equal(emb_sum, 0.0)
+  return tf.to_float(tf.equal(emb_sum, 0.0))
 
 
 def attention_bias_lower_triangle(length):
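For reference, a minimal sketch (TF 1.x, with a hypothetical toy `emb` tensor) of what the updated `embedding_to_padding` produces: a float mask that is 1.0 at all-zero (padding) positions and 0.0 elsewhere, rather than a boolean tensor.

```python
import tensorflow as tf  # TF 1.x, where tf.to_float is available

def embedding_to_padding(emb):
  """Float padding mask: 1.0 where the embedding is all zeros, else 0.0."""
  emb_sum = tf.reduce_sum(tf.abs(emb), axis=-1)
  return tf.to_float(tf.equal(emb_sum, 0.0))

# Toy batch of shape [2, 3, 2]; the last two positions of the second
# sequence are all-zero padding embeddings.
emb = tf.constant([[[0.5, -1.0], [0.2, 0.3], [1.0, 1.0]],
                   [[0.7, 0.1], [0.0, 0.0], [0.0, 0.0]]])

with tf.Session() as sess:
  print(sess.run(embedding_to_padding(emb)))
  # [[0. 0. 0.]
  #  [0. 1. 1.]]
```

Returning floats here lets the callers below multiply the mask directly by -1e9 instead of converting with tf.to_float at every call site, which is what the remaining hunks simplify.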
@@ -197,13 +197,13 @@ def attention_bias_ignore_padding(memory_padding):
   """Create a bias tensor to be added to attention logits.
 
   Args:
-    memory_padding: a boolean `Tensor` with shape [batch, memory_length].
+    memory_padding: a float `Tensor` with shape [batch, memory_length].
 
   Returns:
     a `Tensor` with shape [batch, 1, 1, memory_length].
   """
-  ret = tf.to_float(memory_padding) * -1e9
-  return tf.expand_dims(tf.expand_dims(ret, 1), 1)
+  ret = memory_padding * -1e9
+  return tf.expand_dims(tf.expand_dims(ret, axis=1), axis=1)
 
 
 def attention_bias_proximal(length):
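Because `memory_padding` now arrives as a float mask, `attention_bias_ignore_padding` can scale it directly. A minimal sketch (hypothetical toy shapes, TF 1.x) of how the resulting bias silences padded keys once it is added to the attention logits before the softmax:

```python
import tensorflow as tf  # TF 1.x

memory_padding = tf.constant([[0.0, 0.0, 1.0]])   # [batch=1, memory_length=3]
bias = tf.expand_dims(tf.expand_dims(memory_padding * -1e9, axis=1), axis=1)
# bias: [1, 1, 1, 3]; broadcasts over heads and query positions.

logits = tf.zeros([1, 2, 4, 3])         # [batch, heads, query_length, memory_length]
weights = tf.nn.softmax(logits + bias)  # the padded key receives ~0 attention weight

with tf.Session() as sess:
  print(sess.run(weights)[0, 0, 0])     # approx. [0.5, 0.5, 0.0]
```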
@@ -523,8 +523,7 @@ def pad_l_and_r(x, pad_length):
   # [batch, heads, blocks, block_length, dim]
   k_new = tf.transpose(k_new, [2, 3, 0, 1, 4])
 
-  attention_bias = tf.expand_dims(
-      tf.to_float(embedding_to_padding(k_new)) * -1e9, axis=-2)
+  attention_bias = tf.expand_dims(embedding_to_padding(k_new) * -1e9, axis=-2)
 
   v_t = tf.transpose(v, [2, 0, 1, 3])
   v_new = tf.gather(v_t, gather_indices)
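The local-attention path gets the same simplification: `embedding_to_padding(k_new)` is already a float mask, so the extra `tf.to_float` wrapper is dropped. A small shape sketch under assumed toy dimensions (the helper is repeated only to keep the snippet self-contained):

```python
import tensorflow as tf  # TF 1.x

def embedding_to_padding(emb):
  # Same helper as above: 1.0 where the embedding is all zeros.
  return tf.to_float(tf.equal(tf.reduce_sum(tf.abs(emb), axis=-1), 0.0))

k_new = tf.zeros([2, 4, 3, 5, 8])  # hypothetical [batch, heads, blocks, block_length, depth]

attention_bias = tf.expand_dims(embedding_to_padding(k_new) * -1e9, axis=-2)
print(attention_bias.shape)        # (2, 4, 3, 1, 5)
# The length-1 axis broadcasts over query positions within each block when the
# bias is added to the block-local attention logits.
```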