@@ -1336,6 +1336,133 @@ def transformer_prepare_decoder(targets, hparams, features=None):
   return (decoder_input, decoder_self_attention_bias)
 
 
+def transformer_decoder_layer(decoder_input,
+                              decoder_self_attention_bias,
+                              layer_idx,
+                              hparams,
+                              encoder_output=None,
+                              encoder_decoder_attention_bias=None,
+                              cache=None,
+                              decode_loop_step=None,
+                              nonpadding=None,
+                              save_weights_to=None,
+                              make_image_summary=False,
+                              losses=None,
+                              layer_collection=None,
+                              recurrent_memory_by_layer=None,
+                              chunk_number=None):
+  """A single transformer decoder layer."""
+  x = decoder_input
+  layer = layer_idx
+  layer_name = "layer_%d" % layer
+  layer_cache = cache[layer_name] if cache is not None else None
+
+  attention_dropout_broadcast_dims = (
+      common_layers.comma_separated_string_to_integer_list(
+          getattr(hparams, "attention_dropout_broadcast_dims", "")))
+
+  if recurrent_memory_by_layer is not None:
+    recurrent_memory = recurrent_memory_by_layer[layer_name]
+  else:
+    recurrent_memory = None
+
+  if layer < hparams.get("num_area_layers", 0):
+    max_area_width = hparams.get("max_area_width", 1)
+    max_area_height = hparams.get("max_area_height", 1)
+    memory_height = hparams.get("max_area_height", 1)
+  else:
+    max_area_width = 1
+    max_area_height = 1
+    memory_height = 1
+  with tf.variable_scope(layer_name):
+    with tf.variable_scope("self_attention"):
+      y = common_attention.multihead_attention(
+          common_layers.layer_preprocess(
+              x, hparams, layer_collection=layer_collection),
+          None,
+          decoder_self_attention_bias,
+          hparams.attention_key_channels or hparams.hidden_size,
+          hparams.attention_value_channels or hparams.hidden_size,
+          hparams.hidden_size,
+          hparams.num_heads,
+          hparams.attention_dropout,
+          attention_type=hparams.self_attention_type,
+          max_relative_position=hparams.max_relative_position,
+          heads_share_relative_embedding=(
+              hparams.heads_share_relative_embedding),
+          add_relative_to_values=hparams.add_relative_to_values,
+          save_weights_to=save_weights_to,
+          cache=layer_cache,
+          make_image_summary=make_image_summary,
+          dropout_broadcast_dims=attention_dropout_broadcast_dims,
+          max_length=hparams.get("max_length"),
+          decode_loop_step=decode_loop_step,
+          vars_3d=hparams.get("attention_variables_3d"),
+          activation_dtype=hparams.get("activation_dtype", "float32"),
+          weight_dtype=hparams.get("weight_dtype", "float32"),
+          layer_collection=layer_collection,
+          recurrent_memory=recurrent_memory,
+          chunk_number=chunk_number,
+          hard_attention_k=hparams.get("hard_attention_k", 0),
+          max_area_width=max_area_width,
+          max_area_height=max_area_height,
+          memory_height=memory_height,
+          area_key_mode=hparams.get("area_key_mode", "none"),
+          area_value_mode=hparams.get("area_value_mode", "none"),
+          training=(hparams.get(
+              "mode",
+              tf.estimator.ModeKeys.TRAIN) == tf.estimator.ModeKeys.TRAIN))
+      x = common_layers.layer_postprocess(x, y, hparams)
+    if encoder_output is not None:
+      with tf.variable_scope("encdec_attention"):
+        y = common_attention.multihead_attention(
+            common_layers.layer_preprocess(
+                x, hparams, layer_collection=layer_collection),
+            encoder_output,
+            encoder_decoder_attention_bias,
+            hparams.attention_key_channels or hparams.hidden_size,
+            hparams.attention_value_channels or hparams.hidden_size,
+            hparams.hidden_size,
+            hparams.num_heads,
+            hparams.attention_dropout,
+            max_relative_position=hparams.max_relative_position,
+            heads_share_relative_embedding=(
+                hparams.heads_share_relative_embedding),
+            add_relative_to_values=hparams.add_relative_to_values,
+            save_weights_to=save_weights_to,
+            cache=layer_cache,
+            make_image_summary=make_image_summary,
+            dropout_broadcast_dims=attention_dropout_broadcast_dims,
+            max_length=hparams.get("max_length"),
+            vars_3d=hparams.get("attention_variables_3d"),
+            activation_dtype=hparams.get("activation_dtype", "float32"),
+            weight_dtype=hparams.get("weight_dtype", "float32"),
+            layer_collection=layer_collection,
+            hard_attention_k=hparams.get("hard_attention_k", 0),
+            max_area_width=max_area_width,
+            max_area_height=max_area_height,
+            memory_height=memory_height,
+            area_key_mode=hparams.get("area_key_mode", "none"),
+            area_value_mode=hparams.get("area_value_mode", "none"),
+            training=(hparams.get(
+                "mode",
+                tf.estimator.ModeKeys.TRAIN) == tf.estimator.ModeKeys.TRAIN))
+        x = common_layers.layer_postprocess(x, y, hparams)
+    with tf.variable_scope("ffn"):
+      y = transformer_ffn_layer(
+          common_layers.layer_preprocess(
+              x, hparams, layer_collection=layer_collection),
+          hparams,
+          conv_padding="LEFT",
+          nonpadding_mask=nonpadding,
+          losses=losses,
+          cache=layer_cache,
+          decode_loop_step=decode_loop_step,
+          layer_collection=layer_collection)
+      x = common_layers.layer_postprocess(x, y, hparams)
+    return x
+
+
 def transformer_decoder(decoder_input,
                         encoder_output,
                         decoder_self_attention_bias,
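# --- Illustrative sketch, not part of the commit above ---
# A minimal NumPy rendering of the residual sublayer pattern that
# transformer_decoder_layer applies around each of its three sublayers
# (self-attention, optional encoder-decoder attention, feed-forward),
# assuming the common pre-norm setting in which layer_preprocess normalizes
# and layer_postprocess adds the residual; dropout is omitted for brevity.
import numpy as np

def layer_norm(x, eps=1e-6):
  # Normalize over the last (hidden) dimension.
  mean = x.mean(axis=-1, keepdims=True)
  var = x.var(axis=-1, keepdims=True)
  return (x - mean) / np.sqrt(var + eps)

def residual_sublayer(x, sublayer_fn):
  # Shape of the pattern: x = layer_postprocess(x, sublayer_fn(layer_preprocess(x)))
  y = sublayer_fn(layer_norm(x))  # layer_preprocess: pre-normalization
  return x + y                    # layer_postprocess: residual addition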
@@ -1350,8 +1477,7 @@ def transformer_decoder(decoder_input,
                         losses=None,
                         layer_collection=None,
                         recurrent_memory_by_layer=None,
-                        chunk_number=None,
-                        ):
+                        chunk_number=None):
   """A stack of transformer layers.
 
   Args:
@@ -1377,8 +1503,8 @@ def transformer_decoder(decoder_input,
       key created from the variable scope (including name).
     make_image_summary: Whether to make an attention image summary.
     losses: optional list onto which to append extra training losses
-    layer_collection: A tensorflow_kfac.LayerCollection. Only used by the
-      KFAC optimizer. Default is None.
+    layer_collection: A tensorflow_kfac.LayerCollection. Only used by the KFAC
+      optimizer. Default is None.
     recurrent_memory_by_layer: Optional dict, mapping layer names to instances
       of transformer_memory.RecurrentMemory. Default is None.
     chunk_number: an optional integer Tensor with shape [batch] used to operate
@@ -1388,9 +1514,6 @@ def transformer_decoder(decoder_input,
13881514 y: a Tensors
13891515 """
13901516 x = decoder_input
1391- attention_dropout_broadcast_dims = (
1392- common_layers .comma_separated_string_to_integer_list (
1393- getattr (hparams , "attention_dropout_broadcast_dims" , "" )))
13941517
13951518 mlperf_log .transformer_print (
13961519 key = mlperf_log .MODEL_HP_NUM_HIDDEN_LAYERS ,
@@ -1410,106 +1533,26 @@ def transformer_decoder(decoder_input,
       hparams=hparams)
 
   with tf.variable_scope(name):
-    for layer in range(hparams.num_decoder_layers or hparams.num_hidden_layers):
-      layer_name = "layer_%d" % layer
-      layer_cache = cache[layer_name] if cache is not None else None
-      if recurrent_memory_by_layer is not None:
-        recurrent_memory = recurrent_memory_by_layer[layer_name]
-      else:
-        recurrent_memory = None
+    for layer_idx in range(hparams.num_decoder_layers or
+                           hparams.num_hidden_layers):
+      x = transformer_decoder_layer(
+          x,
+          decoder_self_attention_bias,
+          layer_idx,
+          hparams,
+          encoder_decoder_attention_bias=encoder_decoder_attention_bias,
+          encoder_output=encoder_output,
+          cache=cache,
+          decode_loop_step=decode_loop_step,
+          nonpadding=nonpadding,
+          save_weights_to=save_weights_to,
+          make_image_summary=make_image_summary,
+          losses=losses,
+          layer_collection=layer_collection,
+          recurrent_memory_by_layer=recurrent_memory_by_layer,
+          chunk_number=chunk_number,
+          )
 
-      if layer < hparams.get("num_area_layers", 0):
-        max_area_width = hparams.get("max_area_width", 1)
-        max_area_height = hparams.get("max_area_height", 1)
-        memory_height = hparams.get("max_area_height", 1)
-      else:
-        max_area_width = 1
-        max_area_height = 1
-        memory_height = 1
-      with tf.variable_scope(layer_name):
-        with tf.variable_scope("self_attention"):
-          y = common_attention.multihead_attention(
-              common_layers.layer_preprocess(
-                  x, hparams, layer_collection=layer_collection),
-              None,
-              decoder_self_attention_bias,
-              hparams.attention_key_channels or hparams.hidden_size,
-              hparams.attention_value_channels or hparams.hidden_size,
-              hparams.hidden_size,
-              hparams.num_heads,
-              hparams.attention_dropout,
-              attention_type=hparams.self_attention_type,
-              max_relative_position=hparams.max_relative_position,
-              heads_share_relative_embedding=(
-                  hparams.heads_share_relative_embedding),
-              add_relative_to_values=hparams.add_relative_to_values,
-              save_weights_to=save_weights_to,
-              cache=layer_cache,
-              make_image_summary=make_image_summary,
-              dropout_broadcast_dims=attention_dropout_broadcast_dims,
-              max_length=hparams.get("max_length"),
-              decode_loop_step=decode_loop_step,
-              vars_3d=hparams.get("attention_variables_3d"),
-              activation_dtype=hparams.get("activation_dtype", "float32"),
-              weight_dtype=hparams.get("weight_dtype", "float32"),
-              layer_collection=layer_collection,
-              recurrent_memory=recurrent_memory,
-              chunk_number=chunk_number,
-              hard_attention_k=hparams.get("hard_attention_k", 0),
-              max_area_width=max_area_width,
-              max_area_height=max_area_height,
-              memory_height=memory_height,
-              area_key_mode=hparams.get("area_key_mode", "none"),
-              area_value_mode=hparams.get("area_value_mode", "none"),
-              training=(hparams.get("mode", tf.estimator.ModeKeys.TRAIN)
-                        == tf.estimator.ModeKeys.TRAIN))
-          x = common_layers.layer_postprocess(x, y, hparams)
-        if encoder_output is not None:
-          with tf.variable_scope("encdec_attention"):
-            y = common_attention.multihead_attention(
-                common_layers.layer_preprocess(
-                    x, hparams, layer_collection=layer_collection),
-                encoder_output,
-                encoder_decoder_attention_bias,
-                hparams.attention_key_channels or hparams.hidden_size,
-                hparams.attention_value_channels or hparams.hidden_size,
-                hparams.hidden_size,
-                hparams.num_heads,
-                hparams.attention_dropout,
-                max_relative_position=hparams.max_relative_position,
-                heads_share_relative_embedding=(
-                    hparams.heads_share_relative_embedding),
-                add_relative_to_values=hparams.add_relative_to_values,
-                save_weights_to=save_weights_to,
-                cache=layer_cache,
-                make_image_summary=make_image_summary,
-                dropout_broadcast_dims=attention_dropout_broadcast_dims,
-                max_length=hparams.get("max_length"),
-                vars_3d=hparams.get("attention_variables_3d"),
-                activation_dtype=hparams.get("activation_dtype", "float32"),
-                weight_dtype=hparams.get("weight_dtype", "float32"),
-                layer_collection=layer_collection,
-                hard_attention_k=hparams.get("hard_attention_k", 0),
-                max_area_width=max_area_width,
-                max_area_height=max_area_height,
-                memory_height=memory_height,
-                area_key_mode=hparams.get("area_key_mode", "none"),
-                area_value_mode=hparams.get("area_value_mode", "none"),
-                training=(hparams.get("mode", tf.estimator.ModeKeys.TRAIN)
-                          == tf.estimator.ModeKeys.TRAIN))
-            x = common_layers.layer_postprocess(x, y, hparams)
-        with tf.variable_scope("ffn"):
-          y = transformer_ffn_layer(
-              common_layers.layer_preprocess(
-                  x, hparams, layer_collection=layer_collection),
-              hparams,
-              conv_padding="LEFT",
-              nonpadding_mask=nonpadding,
-              losses=losses,
-              cache=layer_cache,
-              decode_loop_step=decode_loop_step,
-              layer_collection=layer_collection)
-          x = common_layers.layer_postprocess(x, y, hparams)
     # if normalization is done in layer_preprocess, then it should also be done
     # on the output, since the output can grow very large, being the sum of
     # a whole stack of unnormalized layer outputs.
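# --- Hypothetical usage sketch, not part of the commit above ---
# Shows how the extracted helper composes into a stack, mirroring the new loop
# in transformer_decoder; `hparams`, `tf`, `common_layers`, and
# `transformer_decoder_layer` are assumed to be in scope as in the diff, and
# the bias tensors come from the usual tensor2tensor attention utilities.
def build_decoder_stack(decoder_input, decoder_self_attention_bias, hparams,
                        encoder_output=None,
                        encoder_decoder_attention_bias=None,
                        name="decoder"):
  x = decoder_input
  num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
  with tf.variable_scope(name):
    for layer_idx in range(num_layers):
      # One call applies self-attention, optional encoder-decoder attention,
      # and the feed-forward sublayer for layer `layer_idx`.
      x = transformer_decoder_layer(
          x,
          decoder_self_attention_bias,
          layer_idx,
          hparams,
          encoder_output=encoder_output,
          encoder_decoder_attention_bias=encoder_decoder_attention_bias)
    # Final layer_preprocess normalizes the accumulated residual stream, just
    # as transformer_decoder does after its loop.
    return common_layers.layer_preprocess(x, hparams)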