diff --git a/init2winit/model_lib/rope_nanodo.py b/init2winit/model_lib/rope_nanodo.py
index 7c917b36..82b35cf2 100644
--- a/init2winit/model_lib/rope_nanodo.py
+++ b/init2winit/model_lib/rope_nanodo.py
@@ -70,7 +70,7 @@ class DoConfig:
   """Hyper-parameters for Transformer decoder-only."""
-  D: int  # model/embed dim = qkv dim
+  D: int  # embed dim
   H: int  # num attention heads
   N: int  # number of transformer block layers
   V: int  # vocab size
@@ -101,14 +101,19 @@ def __call__(self, x_BxLxD: jax.Array):
     linear = partial(
         nn.Dense, kernel_init=cfg.kernel_init, use_bias=False, dtype=cfg.dtype
     )
-    if cfg.mlp_activation == 'glu':
-      mlp_activation = nn.glu
-      x_BxLxF = linear(cfg.F)(x_BxLxD)
-    elif cfg.mlp_activation == 'gelu':
+    if cfg.mlp_activation == 'gelu':
       mlp_activation = nn.gelu
-      hidden_dim = cfg.multiple_of * (
-          (cfg.F + cfg.multiple_of - 1) // cfg.multiple_of
-      )
+      x_BxLxF = linear(cfg.F)(x_BxLxD)
+    elif cfg.mlp_activation == 'glu':
+      mlp_activation = nn.glu
+      # Adjust hidden dimension to keep the number of parameters invariant to
+      # the activation function used since the GLU MLP has 3 * hidden_dim * D
+      # parameters instead of 2 * hidden_dim * D parameters.
+      hidden_dim = cfg.F * 2 / 3
+      # Round up to the nearest multiple of cfg.multiple_of
+      hidden_dim = int(cfg.multiple_of * (
+          (hidden_dim + cfg.multiple_of - 1) // cfg.multiple_of
+      ))
+      # Double the hidden dimension for GLU
+      x_BxLxF = linear(2 * hidden_dim)(x_BxLxD)
     else:
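
For reference, below is a minimal standalone sketch (not part of the diff) of the parameter-count reasoning behind the new `hidden_dim` calculation. The values of `D`, `F`, and `multiple_of` are illustrative placeholders, and the counts assume the bias-free `nn.Dense` projections used in this module: a GELU MLP has `2 * F * D` parameters, while a GLU MLP with the doubled input projection has `3 * hidden_dim * D`.

```python
# Standalone sketch of the hidden-dim adjustment above.
# Example values only; the real ones come from DoConfig.
D = 512            # embed dim
F = 2048           # MLP hidden dim used by the GELU branch
multiple_of = 128  # rounding granularity

# GELU MLP: D -> F and F -> D projections (no biases).
gelu_params = 2 * F * D

# GLU MLP: D -> 2*hidden_dim input projection plus hidden_dim -> D output
# projection, i.e. 3 * hidden_dim * D parameters. Matching the GELU count
# gives hidden_dim = 2 * F / 3, rounded up to a multiple of `multiple_of`.
hidden_dim = F * 2 / 3
hidden_dim = int(multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of))

glu_params = 2 * hidden_dim * D + hidden_dim * D
print(hidden_dim)               # 1408 for these example values
print(gelu_params, glu_params)  # 2097152 vs 2162688
```

With these example values the GLU MLP comes out slightly larger than the GELU MLP, since `hidden_dim` is rounded up to the next multiple of `multiple_of` rather than to the nearest one.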