 import tensorflow as tf
 from tensorflow.python.util import nest
 
-# Track tuple of state and attention values.
-AttentionTuple = collections.namedtuple("AttentionTuple", ("state",
-                                                           "attention"))
-
-
-class ExternalAttentionCellWrapper(tf.contrib.rnn.RNNCell):
-  """Wrapper for external attention states for an encoder-decoder setup."""
-
-  def __init__(self,
-               cell,
-               attn_states,
-               attn_vec_size=None,
-               input_size=None,
-               state_is_tuple=True,
-               reuse=None):
49- """Create a cell with attention.
50-
51- Args:
52- cell: an RNNCell, an attention is added to it.
53- attn_states: External attention states typically the encoder output in the
54- form [batch_size, time steps, hidden size]
55- attn_vec_size: integer, the number of convolutional features calculated
56- on attention state and a size of the hidden layer built from
57- base cell state. Equal attn_size to by default.
58- input_size: integer, the size of a hidden linear layer,
59- built from inputs and attention. Derived from the input tensor
60- by default.
61- state_is_tuple: If True, accepted and returned states are n-tuples, where
62- `n = len(cells)`. Must be set to True else will raise an exception
63- concatenated along the column axis.
64- reuse: (optional) Python boolean describing whether to reuse variables
65- in an existing scope. If not `True`, and the existing scope already has
66- the given variables, an error is raised.
67- Raises:
68- TypeError: if cell is not an RNNCell.
69- ValueError: if the flag `state_is_tuple` is `False` or if shape of
70- `attn_states` is not 3 or if innermost dimension (hidden size) is None.
71- """
-    super(ExternalAttentionCellWrapper, self).__init__(_reuse=reuse)
-    if not state_is_tuple:
-      raise ValueError("Only tuple state is supported")
-
-    self._cell = cell
-    self._input_size = input_size
-
-    # Validate attn_states shape.
-    attn_shape = attn_states.get_shape()
-    if not attn_shape or len(attn_shape) != 3:
-      raise ValueError("attn_shape must be rank 3")
-
-    self._attn_states = attn_states
-    self._attn_size = attn_shape[2].value
-    if self._attn_size is None:
-      raise ValueError("Hidden size of attn_states cannot be None")
-
-    self._attn_vec_size = attn_vec_size
-    if self._attn_vec_size is None:
-      self._attn_vec_size = self._attn_size
-
-    self._reuse = reuse
-
-  @property
-  def state_size(self):
-    return AttentionTuple(self._cell.state_size, self._attn_size)
-
-  @property
-  def output_size(self):
-    return self._attn_size
-
-  def combine_state(self, previous_state):
-    """Combines previous state (from encoder) with internal attention values.
-
-    You must use this function to derive the initial state passed into
-    this cell, as it expects a named tuple (AttentionTuple).
-
-    Args:
-      previous_state: State from another block that will be fed into this cell;
-        must have the same structure as the state of the wrapped cell.
-    Returns:
-      Combined state (AttentionTuple).
-    """
-    batch_size = self._attn_states.get_shape()[0].value
-    if batch_size is None:
-      batch_size = tf.shape(self._attn_states)[0]
-    zeroed_state = self.zero_state(batch_size, self._attn_states.dtype)
-    return AttentionTuple(previous_state, zeroed_state.attention)
-
121- def call (self , inputs , state ):
122- """Long short-term memory cell with attention (LSTMA)."""
123-
124- if not isinstance (state , AttentionTuple ):
125- raise TypeError ("State must be of type AttentionTuple" )
126-
127- state , attns = state
128- attn_states = self ._attn_states
129- attn_length = attn_states .get_shape ()[1 ].value
130- if attn_length is None :
131- attn_length = tf .shape (attn_states )[1 ]
132-
133- input_size = self ._input_size
134- if input_size is None :
135- input_size = inputs .get_shape ().as_list ()[1 ]
136- if attns is not None :
137- inputs = tf .layers .dense (tf .concat ([inputs , attns ], axis = 1 ), input_size )
138- lstm_output , new_state = self ._cell (inputs , state )
139-
140- new_state_cat = tf .concat (nest .flatten (new_state ), 1 )
141- new_attns = self ._attention (new_state_cat , attn_states , attn_length )
142-
143- with tf .variable_scope ("attn_output_projection" ):
144- output = tf .layers .dense (
145- tf .concat ([lstm_output , new_attns ], axis = 1 ), self ._attn_size )
146-
147- new_state = AttentionTuple (new_state , new_attns )
148-
149- return output , new_state
150-
-  def _attention(self, query, attn_states, attn_length):
-    conv2d = tf.nn.conv2d
-    reduce_sum = tf.reduce_sum
-    softmax = tf.nn.softmax
-    tanh = tf.tanh
-
-    with tf.variable_scope("attention"):
-      k = tf.get_variable("attn_w",
-                          [1, 1, self._attn_size, self._attn_vec_size])
-      v = tf.get_variable("attn_v", [self._attn_vec_size, 1])
-      hidden = tf.reshape(attn_states, [-1, attn_length, 1, self._attn_size])
-      hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME")
-      y = tf.layers.dense(query, self._attn_vec_size)
-      y = tf.reshape(y, [-1, 1, 1, self._attn_vec_size])
-      s = reduce_sum(v * tanh(hidden_features + y), [2, 3])
-      a = softmax(s)
-      d = reduce_sum(tf.reshape(a, [-1, attn_length, 1, 1]) * hidden, [1, 2])
-      new_attns = tf.reshape(d, [-1, self._attn_size])
-
-    return new_attns
-
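Note: the `_attention` helper removed above computes additive (Bahdanau-style)
attention, expressed through a 1x1 convolution over the reshaped encoder states.
For orientation only, here is a minimal sketch of the same score using plain
matrix ops; the weight tensors `w_keys`, `w_query` and `v` below are
illustrative and not part of the original code:

    import tensorflow as tf

    def additive_attention_sketch(query, attn_states, w_keys, w_query, v):
      """query: [B, S]; attn_states: [B, T, H]; w_keys: [H, A]; w_query: [S, A]; v: [A]."""
      keys = tf.tensordot(attn_states, w_keys, axes=[[2], [0]])  # [B, T, A]
      q = tf.expand_dims(tf.matmul(query, w_query), 1)           # [B, 1, A]
      scores = tf.reduce_sum(v * tf.tanh(keys + q), axis=2)      # [B, T]
      alpha = tf.nn.softmax(scores)                              # weights over encoder steps
      # Weighted sum of encoder states gives the new attention (context) vector.
      return tf.reduce_sum(tf.expand_dims(alpha, 2) * attn_states, axis=1)  # [B, H]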
 
 def lstm(inputs, hparams, train, name, initial_state=None):
   """Run LSTM cell on inputs, assuming they are [batch x time x size]."""
@@ -189,7 +51,7 @@ def dropout_lstm_cell():
 
 
 def lstm_attention_decoder(inputs, hparams, train, name, initial_state,
-                           attn_states):
+                           encoder_outputs):
   """Run LSTM cell with attention on inputs of shape [batch x time x size]."""
 
   def dropout_lstm_cell():
@@ -198,18 +60,36 @@ def dropout_lstm_cell():
         input_keep_prob=1.0 - hparams.dropout * tf.to_float(train))
 
   layers = [dropout_lstm_cell() for _ in range(hparams.num_hidden_layers)]
-  cell = ExternalAttentionCellWrapper(
+  AttentionMechanism = (tf.contrib.seq2seq.LuongAttention if hparams.attention_mechanism == "luong"
+                        else tf.contrib.seq2seq.BahdanauAttention)
+  attention_mechanism = AttentionMechanism(hparams.hidden_size, encoder_outputs)
+
+  cell = tf.contrib.seq2seq.AttentionWrapper(
       tf.nn.rnn_cell.MultiRNNCell(layers),
-      attn_states,
-      attn_vec_size=hparams.attn_vec_size)
-  initial_state = cell.combine_state(initial_state)
+      [attention_mechanism] * hparams.num_heads,
+      attention_layer_size=[hparams.attention_layer_size] * hparams.num_heads,
+      output_attention=(hparams.output_attention == 1))
+
+
+  batch_size = inputs.get_shape()[0].value
+  if batch_size is None:
+    batch_size = tf.shape(inputs)[0]
+
+  initial_state = cell.zero_state(batch_size, tf.float32).clone(cell_state=initial_state)
+
   with tf.variable_scope(name):
-    return tf.nn.dynamic_rnn(
+    output, state = tf.nn.dynamic_rnn(
         cell,
        inputs,
         initial_state=initial_state,
         dtype=tf.float32,
         time_major=False)
+
+    # For multi-head attention, project output back to hidden size.
+    if hparams.output_attention == 1 and hparams.num_heads > 1:
+      output = tf.layers.dense(output, hparams.hidden_size)
+
+    return output, state
 
 
 def lstm_seq2seq_internal(inputs, targets, hparams, train):
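With `output_attention` enabled, `tf.contrib.seq2seq.AttentionWrapper` emits the
attention vector as the cell output, and with several heads those vectors are
concatenated, which is why the decoder projects the result back to
`hparams.hidden_size`. A hedged sketch of how the rewritten decoder might be
called; the tensor names below are illustrative rather than taken from the
surrounding file:

    # Hedged usage sketch; tensor names are illustrative.
    # encoder_outputs:     [batch, src_len, hidden_size] from the encoder LSTM.
    # encoder_final_state: final MultiRNNCell state of the encoder.
    decoder_outputs, _ = lstm_attention_decoder(
        shifted_targets,  # [batch, tgt_len, hidden_size]
        hparams, train, "decoder",
        encoder_final_state, encoder_outputs)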
@@ -273,14 +153,49 @@ def lstm_seq2seq():
   hparams.hidden_size = 128
   hparams.num_hidden_layers = 2
   hparams.initializer = "uniform_unit_scaling"
+  hparams.initializer_gain = 1.0
+  hparams.weight_decay = 0.0
+
+  return hparams
+
+def lstm_attention_base():
+  """Base attention params."""
+  hparams = lstm_seq2seq()
+  hparams.add_hparam("attention_layer_size", hparams.hidden_size)
+  hparams.add_hparam("output_attention", int(True))
+  hparams.add_hparam("num_heads", 1)
   return hparams
 
 
+@registry.register_hparams
+def lstm_bahdanau_attention():
+  """hparams for LSTM with Bahdanau attention."""
+  hparams = lstm_attention_base()
+  hparams.add_hparam("attention_mechanism", "bahdanau")
+  return hparams
+
+@registry.register_hparams
+def lstm_luong_attention():
+  """hparams for LSTM with Luong attention."""
+  hparams = lstm_attention_base()
+  hparams.add_hparam("attention_mechanism", "luong")
+  return hparams
+
 @registry.register_hparams
 def lstm_attention():
-  """hparams for LSTM with attention."""
-  hparams = lstm_seq2seq()
+  """For backwards compatibility; defaults to Bahdanau attention."""
+  return lstm_bahdanau_attention()
 
-  # Attention
-  hparams.add_hparam("attn_vec_size", hparams.hidden_size)
+@registry.register_hparams
+def lstm_bahdanau_attention_multi():
+  """Multi-head Bahdanau attention."""
+  hparams = lstm_bahdanau_attention()
+  hparams.num_heads = 4
   return hparams
+
+@registry.register_hparams
+def lstm_luong_attention_multi():
+  """Multi-head Luong attention."""
+  hparams = lstm_luong_attention()
+  hparams.num_heads = 4
+  return hparams
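The functions above register their hparams sets by name via
`@registry.register_hparams`. A hedged sketch of building one of them directly
and adjusting a few fields (the overridden values are illustrative):

    # Hedged sketch; the overridden values are illustrative.
    hparams = lstm_luong_attention_multi()  # 4-head Luong attention
    hparams.hidden_size = 256
    hparams.attention_layer_size = 256      # added by lstm_attention_base()
    hparams.output_attention = 1            # 1: use the attention vector as cell output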