@@ -65,11 +65,12 @@ def n_outputs(self):
     """Specifies how many data tensors this layer promises as output."""
     return self._n_sections

-  def call(self, inputs, params=(), **kwargs):
+  def call(self, inputs, params=(), state=(), **kwargs):
     rngs = _pop_rng_and_split(kwargs, len(inputs))
-    result = [self._layer(x, params=params, rng=r, **kwargs)
-              for x, r in zip(inputs, rngs)]
-    return tuple(result)
+    results = [self._layer(x, params=params, state=state, rng=r, **kwargs)
+               for x, r in zip(inputs, rngs)]
+    result_outputs, result_states = zip(*results)
+    return tuple(result_outputs), tuple(result_states)

   def new_parameters(self, input_shape, input_dtype, rng):
     first_shape = input_shape[0]
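
Note on the pattern applied throughout this diff: call() now threads a state argument and returns an (output, state) pair, and new_parameters() returns (params, state). A minimal sketch of that calling convention, using a hypothetical stand-alone layer and plain NumPy rather than the real trax backend and tl.Layer base class:

import numpy as np

class ScaleLayer(object):
  """Toy layer illustrating the (output, state) convention; not trax code."""

  def new_parameters(self, input_shape, input_dtype, rng):
    # Parameters and (empty) state are now created and returned together.
    del rng  # unused in this sketch
    return np.ones(input_shape[-1], dtype=input_dtype), ()

  def call(self, x, params=(), state=(), **kwargs):
    # State is passed through unchanged and returned next to the output.
    del kwargs
    return x * params, state

# Usage: callers now unpack two values instead of one.
layer = ScaleLayer()
params, state = layer.new_parameters((2, 3), np.float32, rng=None)
y, state = layer.call(np.ones((2, 3), dtype=np.float32), params=params, state=state)
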
@@ -122,12 +123,13 @@ def __init__(self, n_sections=2, axis=-1):
     self._n_sections = n_sections
     self._axis = axis

-  def call(self, inputs, params=(), **kwargs):
+  def call(self, inputs, params=(), state=(), **kwargs):
     del params, kwargs
-    return tuple(backend.numpy.split(inputs, self._n_sections, self._axis))
+    res = tuple(backend.numpy.split(inputs, self._n_sections, self._axis))
+    return res, state

   def new_parameters(self, input_shapes, input_dtype, rng):
-    return ()
+    return (), ()

   def n_inputs(self):
     """Specifies how many data tensors this layer expects as input."""
@@ -167,17 +169,17 @@ def n_outputs(self):
     return self._n_sections

   def new_parameters(self, input_shape, input_dtype, rng):
-    return ()
+    return (), ()

-  def call(self, inputs, params=(), **kwargs):
+  def call(self, inputs, params=(), state=(), **kwargs):
     del params, kwargs
     x1, x2 = inputs

     x1_split = backend.numpy.split(x1, self._n_sections, self._axis)
     x2_split = backend.numpy.split(x2, self._n_sections, self._axis)

     res = [backend.numpy.concatenate(ys, -1) for ys in zip(x1_split, x2_split)]
-    return tuple(res)
+    return tuple(res), state

   def reverse(self, output, params=(), **kwargs):
     del params, kwargs
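
For reference, the split-and-interleave step in the call() above pairs the i-th chunk of x1 with the i-th chunk of x2 and concatenates each pair on the last axis. A small shape check of that pattern in plain NumPy (the array sizes are made up for illustration; the real layer goes through backend.numpy):

import numpy as np

n_sections, axis = 2, -1
x1 = np.zeros((4, 8))  # two inputs with matching shapes
x2 = np.ones((4, 8))

x1_split = np.split(x1, n_sections, axis)
x2_split = np.split(x2, n_sections, axis)
res = [np.concatenate(ys, -1) for ys in zip(x1_split, x2_split)]

# n_sections outputs, each joining one chunk of x1 with the matching chunk of x2.
assert [r.shape for r in res] == [(4, 8), (4, 8)]
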
@@ -288,7 +290,7 @@ def __init__(self, n_heads=1, d_head=64,
     # The lack of a bias term here is consistent with the tensor2tensor
     # implementation, and shouldn't have an effect on modeling quality.

-  def call(self, x, params, **kwargs):
+  def call(self, x, params, state, **kwargs):
     del kwargs
     seqlen = x.shape[1]
     res = np.dot(x, params)
@@ -300,13 +302,13 @@ def call(self, x, params, **kwargs):
     # n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
     res = np.reshape(res, (-1, seqlen, self._d_head))

-    return res
+    return res, state

   def new_parameters(self, input_shape, input_dtype, rng):
     del input_dtype
     w = self._kernel_initializer(
         (input_shape[-1], self._n_heads * self._d_head), rng)
-    return w
+    return w, ()


 class ComputeAttentionOutput(tl.Layer):
@@ -321,7 +323,7 @@ def __init__(self, n_heads=1, d_model=1024,
     # The lack of a bias term here is consistent with the tensor2tensor
     # implementation, and shouldn't have an effect on modeling quality.

-  def call(self, x, params, **kwargs):
+  def call(self, x, params, state, **kwargs):
     del kwargs
     seqlen = x.shape[1]
     d_head = x.shape[2]
@@ -330,13 +332,13 @@ def call(self, x, params, **kwargs):
     x = np.transpose(x, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
     x = np.reshape(x, (-1, seqlen, self._n_heads * d_head))

-    return np.dot(x, params)
+    return np.dot(x, params), state

   def new_parameters(self, input_shape, input_dtype, rng):
     del input_dtype
     w = self._kernel_initializer(
         (input_shape[-1] * self._n_heads, self._d_model), rng)
-    return w
+    return w, ()


 class ApplyAttentionWrapper(tl.Parallel):
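
Between them, ComputeAttentionHeads and ComputeAttentionOutput split activations into per-head blocks and merge them back. Below is a NumPy-only shape check of that round trip; the intermediate reshape and transpose in ComputeAttentionHeads are not visible in these hunks, so that half is an assumption based on the comment shown above:

import numpy as np

n_batch, seqlen, n_heads, d_head = 2, 5, 4, 8
x = np.zeros((n_batch, seqlen, n_heads * d_head))

# Assumed head split: -> n_batch, n_heads, seqlen, d_head -> n_batch*n_heads, seqlen, d_head
h = np.reshape(x, (n_batch, seqlen, n_heads, d_head))
h = np.transpose(h, (0, 2, 1, 3))
h = np.reshape(h, (-1, seqlen, d_head))
assert h.shape == (n_batch * n_heads, seqlen, d_head)

# Head merge, following the visible ComputeAttentionOutput lines:
o = np.reshape(h, (-1, n_heads, seqlen, d_head))
o = np.transpose(o, (0, 2, 1, 3))  # -> n_batch, seqlen, n_heads, d_head
o = np.reshape(o, (-1, seqlen, n_heads * d_head))
assert o.shape == (n_batch, seqlen, n_heads * d_head)
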
@@ -374,14 +376,14 @@ def __init__(self, dropout, mode):
     self._dropout = dropout
     self._mode = mode

-  def call(self, inputs, params=(), rng=None, **kwargs):
+  def call(self, inputs, params=(), state=(), rng=None, **kwargs):
     del params
     q, k, v = inputs
     mask_size = q.shape[-2]
     mask = np.tril(np.ones((1, mask_size, mask_size), dtype=onp.bool_), k=0)
     res = tl.DotProductAttention(
         q, k, v, mask, dropout=self._dropout, mode=self._mode, rng=rng)
-    return res
+    return res, state

   def forward_and_vjp(self, inputs, ct, params=(), **kwargs):
     # Simultaneous forward pass and backprop through the attention mechanism.
@@ -391,7 +393,7 @@ def do_call(x):
     return output, vjpfun(ct)[0]

   def new_parameters(self, input_shapes, input_dtype, rng):
-    return ()
+    return (), ()

   def n_inputs(self):
     return 3
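
The forward_and_vjp helpers in this file return the forward output together with the cotangent propagated back through the attention computation. A minimal sketch of that pattern with jax.vjp, using a stand-in do_call instead of the real attention call:

import jax
import jax.numpy as jnp

def do_call(x):
  # Stand-in for the attention computation; any differentiable function works.
  return jnp.tanh(x)

def forward_and_vjp(x, ct):
  # One forward pass that also records how to backprop the incoming cotangent.
  output, vjpfun = jax.vjp(do_call, x)
  return output, vjpfun(ct)[0]

x = jnp.ones((2, 3))
ct = jnp.ones((2, 3))  # cotangent arriving from downstream layers
out, x_grad = forward_and_vjp(x, ct)
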
@@ -413,9 +415,9 @@ def __init__(self, loop_stride, dropout, mode):
     else:
       self.dropout = None

-  def call(self, inputs, params=(), **kwargs):
+  def call(self, inputs, params=(), state=(), **kwargs):
     output, _ = self.forward_and_vjp(inputs, None, params=params, **kwargs)
-    return output
+    return output, state

   def forward_and_vjp(self, inputs, ct, params=(), rng=None, **kwargs):
     # This is the core of the memory-efficient attention implementation, where
@@ -547,9 +549,9 @@ def __init__(self, dropout, mode, n_bins=64):
     super(DummyHashedAttention, self).__init__(dropout, mode)
     self.n_bins = n_bins

-  def call(self, inputs, params=(), **kwargs):
+  def call(self, inputs, params=(), state=(), **kwargs):
     output, _ = self.forward_and_vjp(inputs, None, params=params, **kwargs)
-    return output
+    return output, state

   def forward_and_vjp(self, inputs, ct, params=(), **kwargs):
     del params, kwargs