@@ -148,6 +148,8 @@ def print_shape(x, suffix, debug=False):
             hparams.hidden_size,
             hparams.num_heads,
             hparams.attention_dropout,
+            attention_type=("local_mask_right" if hparams.attention_local
+                            else "dot_product"),
             name="decoder_self_attention")
       elif hparams.attention_type == AttentionType.MEMORY_EFFICIENT:
         assert hparams.layer_preprocess_sequence == "n"
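For readers skimming the hunk: the new argument only changes which attention variant the decoder self-attention call receives, driven by the `attention_local` flag added further down. Below is a minimal, self-contained sketch of that selection; the enclosing call is assumed to be the usual `common_attention.multihead_attention`, and the stand-in namespace is illustrative, not the real HParams object.

```python
from types import SimpleNamespace

def decoder_attention_type(hparams):
  # attention_local is stored as an int flag (0/1) in the hparams set,
  # so plain truthiness picks between masked local attention and full
  # dot-product attention, exactly as in the conditional added above.
  return "local_mask_right" if hparams.attention_local else "dot_product"

# Illustrative checks with a stand-in namespace (not tf HParams).
assert decoder_attention_type(SimpleNamespace(attention_local=int(True))) == "local_mask_right"
assert decoder_attention_type(SimpleNamespace(attention_local=int(False))) == "dot_product"
```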
@@ -349,6 +351,7 @@ def attention_lm_moe_base():
   hparams.add_hparam("moe_layers", "2")  # comma separated list of layer numbers
   # moe params. local attention moe.
   hparams.add_hparam("attention_type", AttentionType.MULTIHEAD)
+  hparams.add_hparam("attention_local", int(False))
   hparams.add_hparam("attention_moe_k", 2)
   hparams.add_hparam("attention_num_experts", 16)
   hparams.add_hparam("attention_split_batch", int(False))
@@ -383,6 +386,18 @@ def attention_lm_moe_base_ae():
   return hparams
 
 
+@registry.register_hparams
+def attention_lm_moe_base_local():
391+ """Base model with attention expert."""
+  hparams = attention_lm_moe_base()
+  hparams.attention_local = int(True)
+  hparams.use_sepconv = int(True)
+  hparams.max_length = 0  # max_length == batch_size
+  hparams.eval_drop_long_sequences = int(True)
+  hparams.min_length_bucket = 256  # Avoid cyclic problems for big batches
+  return hparams
+
+
 @registry.register_hparams
 def attention_lm_moe_small():
   """Cheap model for single-gpu training.