Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions init2winit/model_lib/rope_nanodo.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
class DoConfig:
"""Hyper-parameters for Transformer decoder-only."""

D: int # model/embed dim = qkv dim
D: int # embed dim
H: int # num attention heads
N: int # number of transformer block layers
V: int # vocab size
Expand Down Expand Up @@ -101,14 +101,19 @@ def __call__(self, x_BxLxD: jax.Array):
linear = partial(
nn.Dense, kernel_init=cfg.kernel_init, use_bias=False, dtype=cfg.dtype
)
if cfg.mlp_activation == 'glu':
mlp_activation = nn.glu
x_BxLxF = linear(cfg.F)(x_BxLxD)
elif cfg.mlp_activation == 'gelu':
if cfg.mlp_activation == 'gelu':
mlp_activation = nn.gelu
hidden_dim = cfg.multiple_of * (
(cfg.F + cfg.multiple_of - 1) // cfg.multiple_of
)
x_BxLxF = linear(cfg.F)(x_BxLxD)
elif cfg.mlp_activation == 'glu':
mlp_activation = nn.glu
# Adjust hidden dimension to keep the number of parameters invariant to
# the activation function used since the GLU MLP has 3 * hidden_dim * D
# parameters instead of 2 * hidden_dim * D parameters.
hidden_dim = cfg.F * 2 / 3
# Round up to the nearest multiple of cfg.multiple_of
hidden_dim = int(cfg.multiple_of * (
(hidden_dim + cfg.multiple_of - 1) // cfg.multiple_of
))
# Double the hidden dimension for GLU
x_BxLxF = linear(2 * hidden_dim)(x_BxLxD)
else:
Expand Down