@@ -344,23 +344,33 @@ def dot_product_attention(q,
     return tf.matmul(weights, v)


-def masked_local_attention_1d(
-    q, k, v, block_length=128, name=None):
-  """Attention to the source position and a neigborhood to the left of it.
+def local_attention_1d(q, k, v, bias=None,
+                       block_length=128, look_right=True, use_whole_block=False,
+                       truncate_bias=True, name=None):
+  """Attention to the source position and a neighborhood around it.

-  The sequence is divided into blocks of length block_size.
-  Attention for a given query position can only see memory positions
-  less than or equal to the query position, in the corresponding block
-  and the previous block.
+  The sequence is divided into blocks of length block_length. Attention for a
+  given query position can only see memory positions within a certain number
+  of positions before and after it.

-  If mask_right is True, then a target position cannot see greater source
+  If look_right is True, each query attends to roughly block_length//2
+  positions on either side; otherwise it attends to the block_length previous
   positions.

+  If use_whole_block is True, no mask is applied within the local blocks, so
+  the full blocks are used (although when look_right is False the elements to
+  the right of the current position are still masked out). This allows us to
+  attend to more elements without additional overhead, but means the window
+  position and size are not consistent across query positions.
+
   Args:
-    q: a Tensor with shape [batch, heads, length, depth_k]
-    k: a Tensor with shape [batch, heads, length, depth_k]
-    v: a Tensor with shape [batch, heads, length, depth_v]
+    q: a Tensor with shape [batch, heads, length_q, depth_k]
+    k: a Tensor with shape [batch, heads, length_kv, depth_k]
+    v: a Tensor with shape [batch, heads, length_kv, depth_v]
+    bias: an optional Tensor with shape [batch, heads, length_q, length_kv];
+      not currently used
     block_length: an integer
+    look_right: a bool
+    use_whole_block: a bool
+    truncate_bias: a bool, whether to use only the query dimension of the bias
     name: an optional string

   Returns:
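For orientation, here is a minimal usage sketch of the new local_attention_1d defined in this diff. The tensor shapes follow the Args section above; the random inputs, the concrete sizes, and the TF 1.x session boilerplate are illustrative assumptions, not part of the change.

import tensorflow as tf

# Toy shapes (assumed): batch=2, heads=4, length=300, depth_k=depth_v=64.
q = tf.random_normal([2, 4, 300, 64])
k = tf.random_normal([2, 4, 300, 64])
v = tf.random_normal([2, 4, 300, 64])

# Windowed attention: each query sees roughly block_length//2 positions
# on either side of itself.
out_center = local_attention_1d(q, k, v, block_length=128, look_right=True)

# Left-only attention: each query sees up to block_length previous positions.
out_left = local_attention_1d(q, k, v, block_length=128, look_right=False)

with tf.Session() as sess:
  print(sess.run(tf.shape(out_center)))  # [2, 4, 300, 64]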
@@ -372,146 +382,110 @@ def masked_local_attention_1d(
     batch = tf.shape(q)[0]
     heads = tf.shape(q)[1]
     length = tf.shape(q)[2]
-    # If (length < 2 * block_length), then we use only one block.
-    block_length = tf.where(tf.less(length, block_length * 2),
-                            length, block_length)
     depth_k = tf.shape(q)[3]
     depth_v = tf.shape(v)[3]
+
     original_length = length
+
+    # Pad to the desired length.
+    # If (length < block_length), then we use only one block.
+    block_length = tf.where(tf.less(length, block_length),
+                            length, block_length)
     padding_size = tf.mod(-length, block_length)
     length += padding_size
-    padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]]
-    q = tf.pad(q, padding)
-    k = tf.pad(k, padding)
-    v = tf.pad(v, padding)
     num_blocks = tf.div(length, block_length)

-    # compute attention for the first query block.
-    first_q = tf.slice(q, [0, 0, 0, 0], [-1, -1, block_length, -1])
-    first_k = tf.slice(k, [0, 0, 0, 0], [-1, -1, block_length, -1])
-    first_v = tf.slice(v, [0, 0, 0, 0], [-1, -1, block_length, -1])
-    first_output = dot_product_attention(
-        first_q, first_k, first_v, attention_bias_lower_triangle(block_length),
-        name="fist_block")
+    padding = [[0, 0], [0, 0], [0, padding_size], [0, 0]]
+    q = tf.pad(q, padding)

-    # compute attention for all subsequent query blocks.
+    if not look_right:
+      # Add extra padding so we don't have to compute a separate first query
+      # block.
+      extra_padding = [[0, 0], [0, 0], [block_length, padding_size], [0, 0]]
+      bp = [[0, 0], [0, 0], [0, padding_size], [block_length, padding_size]]
+    else:
+      # Shift everything over by half a block so the query sits in the centre
+      # of its window.
+      pad_right = block_length // 2
+      pad_left = block_length - pad_right
+      extra_padding = [[0, 0], [0, 0],
+                       [pad_left, padding_size + pad_right], [0, 0]]
+      bp = [[0, 0], [0, 0],
+            [0, padding_size], [pad_left, padding_size + pad_right]]
+    k = tf.pad(k, extra_padding)
+    v = tf.pad(v, extra_padding)
+
+    # Reshape into blocks.
     q = tf.reshape(q, [batch, heads, num_blocks, block_length, depth_k])
-    k = tf.reshape(k, [batch, heads, num_blocks, block_length, depth_k])
-    v = tf.reshape(v, [batch, heads, num_blocks, block_length, depth_v])
+    k = tf.reshape(k, [batch, heads, num_blocks + 1, block_length, depth_k])
+    v = tf.reshape(v, [batch, heads, num_blocks + 1, block_length, depth_v])

+    # Get local blocks by slicing.
     def local(x):
       """Create a local version of the keys or values."""
       prev_block = tf.slice(
-          x, [0, 0, 0, 0, 0], [-1, -1, num_blocks - 1, -1, -1])
+          x, [0, 0, 0, 0, 0], [-1, -1, num_blocks, -1, -1])
       cur_block = tf.slice(
           x, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
       return tf.concat([prev_block, cur_block], 3)
     local_k = local(k)
     local_v = local(v)
-    tail_q = tf.slice(q, [0, 0, 1, 0, 0], [-1, -1, -1, -1, -1])
-
     local_length = tf.shape(local_k)[3]

-    # [batch, heads, num_blocks - 1, block_length, local_length]
-    attention = tf.matmul(tail_q, local_k, transpose_b=True)
-
-    # make sure source_pos <= target_pos
-    good_part = tf.matrix_band_part(
-        tf.ones([block_length, local_length]), -1, tf.to_int64(block_length))
-    mask = (1.0 - good_part) * -1e9
-    attention += tf.reshape(mask, [1, 1, 1, block_length, local_length])
+    # [batch, heads, num_blocks, block_length, local_length]
+    attention = tf.matmul(q, local_k, transpose_b=True)
+
+    # Apply bias (N.B.: this is not currently working).
+    if bias is not None:
+      with tf.name_scope('bias'):
+        b_batch = tf.shape(bias)[0]
+        b_heads = tf.shape(bias)[1]
+        if truncate_bias:
+          # Use only the query dimension.
+          bias = tf.expand_dims(bias[:, :, :, 0], 2)
+          bias = tf.pad(bias, extra_padding, name='bias_pad_b')
+          bias = tf.reshape(bias,
+                            [b_batch, b_heads, 1, num_blocks + 1, block_length],
+                            name='divide_blocks')
+          local_b = tf.reshape(local(bias),
+                               [b_batch, b_heads, num_blocks, 1, -1],
+                               name='reshape_local')
+        else:
+          bias = tf.pad(bias, bp, name='pad')
+          bias = tf.reshape(bias,
+                            [b_batch, b_heads, num_blocks, block_length,
+                             num_blocks + 1, block_length],
+                            name='divide_blocks')
+          bias = tf.transpose(bias, [4, 2, 0, 1, 3, 5])
+          bias = tf.reshape(bias,
+                            [num_blocks * (num_blocks + 1), b_batch, b_heads,
+                             block_length, block_length], name='combine')
+          indices = (num_blocks + 1) * tf.range(num_blocks)
+          prev_block = tf.gather(bias, indices)
+          cur_block = tf.gather(bias, indices + num_blocks)
+          local_b = tf.concat([prev_block, cur_block], 4)
+          local_b = tf.transpose(local_b, [1, 2, 0, 3, 4])
+        attention += local_b
+
     attention = tf.nn.softmax(attention)
-    # TODO(noam): figure out how to show a summary for the remaining blocks.
-    # The naive way currently causes errors due to empty tensors.
-    # output: [batch, heads, num_blocks-1, block_length, depth_v]
-    output = tf.matmul(attention, local_v)
-    output = tf.reshape(output, [batch, heads, -1, depth_v])
-    output = tf.concat([first_output, output], axis=2)
-    output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
-    output.set_shape(v_shape)
-    return output
-
+
+    # Get the local mask.
+    if not use_whole_block:
+      good_part = tf.matrix_band_part(
+          tf.ones([block_length, local_length]), 0, tf.to_int64(block_length))
+    elif not look_right:
+      good_part = tf.matrix_band_part(
+          tf.ones([block_length, local_length]), -1, tf.to_int64(block_length))
+    else:
+      good_part = tf.ones([block_length, local_length])

-def unmasked_local_attention_1d(q, k, v, block_length=128, filter_width=100,
-                                name=None):
-  """strided block local self-attention.
+    attention *= tf.reshape(good_part, [1, 1, 1, block_length, local_length])

-  Args:
-    q: a Tensor with shape [batch, heads, length, depth_k]
-    k: a Tensor with shape [batch, heads, length, depth_k]
-    v: a Tensor with shape [batch, heads, length, depth_v]
-    block_length: an integer
-    filter_width: an integer indicating how much to look left.
-    name: an optional string
+
+    output = tf.matmul(attention, local_v)
+    output = tf.reshape(output, [batch, heads, -1, depth_v])

-  Returns:
-    a Tensor of shape [batch, heads, length, depth_v]
-  """
-  with tf.variable_scope(name, default_name="local_self_attention_1d",
-                         values=[q, k, v]):
-    v_shape = v.get_shape()
-    depth_v = tf.shape(v)[3]
-    batch_size = tf.shape(q)[0]
-    num_heads = tf.shape(q)[1]
-    original_length = tf.shape(q)[2]
-    # making sure q is a multiple of d
-    def pad_to_multiple(x, pad_length):
-      x_length = tf.shape(x)[2]
-      return tf.pad(x, [[0, 0], [0, 0], [0, -x_length % pad_length], [0, 0]])
-    def pad_l_and_r(x, pad_length):
-      return tf.pad(x, [[0, 0], [0, 0], [pad_length, pad_length], [0, 0]])
-    q = pad_to_multiple(q, block_length)
-    k = pad_to_multiple(k, block_length)
-    v = pad_to_multiple(v, block_length)
-
-    # Setting up q blocks
-    new_q_shape = tf.shape(q)
-    # Setting up q blocks
-    q = tf.reshape(q, [new_q_shape[0], new_q_shape[1],
-                       new_q_shape[2] // block_length,
-                       block_length, new_q_shape[3]])
-
-    # Setting up k and v values
-    k = pad_l_and_r(k, filter_width)
-    v = pad_l_and_r(v, filter_width)
-
-    length = tf.shape(k)[2]
-    full_filter_width = block_length + 2 * filter_width
-    # getting gather indices
-    indices = tf.range(0, length, delta=1, name="index_range")
-    # making indices [1, length, 1] to appy convs
-    indices = tf.reshape(indices, [1, -1, 1])
-    kernel = tf.expand_dims(tf.eye(full_filter_width), axis=1)
-    gather_indices = tf.nn.conv1d(
-        tf.cast(indices, tf.float32),
-        kernel,
-        block_length,
-        padding="VALID",
-        name="gather_conv")
-
-    gather_indices = tf.squeeze(tf.cast(gather_indices, tf.int32), axis=0)
-
-    # [length, batch, heads, dim]
-    k_t = tf.transpose(k, [2, 0, 1, 3])
-    k_new = tf.gather(k_t, gather_indices)
-
-    # [batch, heads, blocks, block_length, dim]
-    k_new = tf.transpose(k_new, [2, 3, 0, 1, 4])
-
-    attention_bias = tf.expand_dims(
-        tf.to_float(embedding_to_padding(k_new)) * -1e9, axis=-2)
-
-    v_t = tf.transpose(v, [2, 0, 1, 3])
-    v_new = tf.gather(v_t, gather_indices)
-    v_new = tf.transpose(v_new, [2, 3, 0, 1, 4])
-
-    logits = tf.matmul(q, k_new, transpose_b=True)
-
-    attention = tf.nn.softmax(logits + attention_bias)
-    output = tf.matmul(attention, v_new)
-
-    output = tf.reshape(output, [batch_size, num_heads, -1, depth_v])
-    # Remove the padding if introduced
+    # Remove the added padding.
     output = tf.slice(output, [0, 0, 0, 0], [-1, -1, original_length, -1])
     output.set_shape(v_shape)
     return output
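The pad/reshape/concatenate bookkeeping above is easier to see on a toy example. The NumPy sketch below (illustrative only, not part of the diff) mirrors what local() does for the look_right=False case: pad one extra block on the left, cut the sequence into blocks, then pair each block with its predecessor so every query block sees 2 * block_length keys. The good_part band mask then hides any positions to the right of each query inside its window.

import numpy as np

block_length = 4
x = np.arange(8)                      # stand-in key/value positions, length 8

# Pad one block of "previous" positions on the left (shown here as -1).
padded = np.concatenate([np.full(block_length, -1), x])
blocks = padded.reshape(-1, block_length)        # [num_blocks + 1, block_length]

prev_block = blocks[:-1]                         # blocks 0 .. num_blocks-1
cur_block = blocks[1:]                           # blocks 1 .. num_blocks
local_window = np.concatenate([prev_block, cur_block], axis=1)
print(local_window)
# [[-1 -1 -1 -1  0  1  2  3]
#  [ 0  1  2  3  4  5  6  7]]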
@@ -542,8 +516,8 @@ def multihead_attention(query_antecedent,
     dropout_rate: a floating point number
     image_shapes: optional tuple of integer scalars.
       see comments for attention_image_summary()
-    attention_type: a string, either "dot_product" or "local_mask_right" or
-                    "local_unmasked"
+    attention_type: a string, either "dot_product" or "local" or
+                    "local_mask_right"
     block_length: an integer - relevant for "local_mask_right"
     name: an optional string

@@ -592,11 +566,12 @@ def multihead_attention(query_antecedent,
     if attention_type == "dot_product":
       x = dot_product_attention(
           q, k, v, bias, dropout_rate, image_shapes)
-    elif attention_type == "local_mask_right":
-      x = masked_local_attention_1d(q, k, v, block_length=block_length)
+    elif attention_type == "local":
+      x = local_attention_1d(q, k, v, block_length=block_length)
     else:
-      assert attention_type == "local_unmasked"
-      x = unmasked_local_attention_1d(q, k, v, block_length=block_length)
+      assert attention_type == "local_mask_right"
+      x = local_attention_1d(
+          q, k, v, block_length=block_length, look_right=False)
     x = combine_heads(x)
     x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
     return x
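With this renaming, callers select the local variants through multihead_attention's attention_type argument. A hedged sketch of the two local modes follows; the surrounding parameters (total_key_depth, total_value_depth, output_depth, num_heads, and the input tensor x) are assumed from the function's usual signature rather than shown in this diff, and the concrete values are illustrative.

# Windowed local attention: queries may look right within their block.
y = multihead_attention(
    query_antecedent=x, memory_antecedent=None, bias=None,
    total_key_depth=512, total_value_depth=512, output_depth=512,
    num_heads=8, dropout_rate=0.0,
    attention_type="local", block_length=128)

# Masked local attention: queries only look left, matching the old
# "local_mask_right" behaviour.
y = multihead_attention(
    query_antecedent=x, memory_antecedent=None, bias=None,
    total_key_depth=512, total_value_depth=512, output_depth=512,
    num_heads=8, dropout_rate=0.0,
    attention_type="local_mask_right", block_length=128)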