From 3b831e0a21bf1dcc68371db386c628d7a7dd4ca2 Mon Sep 17 00:00:00 2001
From: Priya Kasimbeg
Date: Sat, 27 Sep 2025 22:34:24 -0700
Subject: [PATCH] internal change

PiperOrigin-RevId: 812362672
---
 init2winit/model_lib/rope_nanodo.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/init2winit/model_lib/rope_nanodo.py b/init2winit/model_lib/rope_nanodo.py
index 309dba65..7c917b36 100644
--- a/init2winit/model_lib/rope_nanodo.py
+++ b/init2winit/model_lib/rope_nanodo.py
@@ -38,10 +38,10 @@
 
 DEFAULT_HPARAMS = config_dict.ConfigDict(
     dict(
-        emb_dim=1024,  # model/embed dim = qkv dim
+        emb_dim=512,  # model/embed dim = qkv dim
         num_heads=8,  # num attention heads
         num_layers=12,  # number of transformer block layers
-        mlp_dim=1024,  # FF inner dimension
+        mlp_dim=2048,  # FF inner dimension
         rng_seed=-1,
         computation_dtype='bfloat16',
         model_dtype='float32',
@@ -61,6 +61,7 @@
         use_shallue_label_smoothing=False,
         normalization='rmsnorm',
         mlp_activation='glu',
+        qk_norm=True,
     )
 )
 
@@ -85,6 +86,7 @@ class DoConfig:
   tie_embeddings: bool = True  # Whether to tie input and output embeddings
   mlp_activation: str = 'glu'
   normalization: str = 'rmsnorm'
+  qk_norm: bool = True
 
 
 class Mlp(nn.Module):
@@ -196,6 +198,8 @@ def setup(self):
         use_bias=False,
         dtype=cfg.dtype,
     )
+    self.layer_norm_q = nn.LayerNorm(dtype=cfg.dtype, use_bias=False)
+    self.layer_norm_k = nn.LayerNorm(dtype=cfg.dtype, use_bias=False)
 
   def __call__(self, x_BxLxD: jax.Array):
     cfg = self.cfg
@@ -205,6 +209,10 @@ def __call__(self, x_BxLxD: jax.Array):
     k_BxLxHxDh = self.multilinear_key(x_BxLxD)
     v_BxLxHxDh = self.multilinear_value(x_BxLxD)
 
+    if cfg.qk_norm:
+      q_BxLxHxDh = self.layer_norm_q(q_BxLxHxDh)
+      k_BxLxHxDh = self.layer_norm_k(k_BxLxHxDh)
+
     # Apply rotary embeddings to Q and K
     q_BxLxHxDh, k_BxLxHxDh = apply_rope(q_BxLxHxDh, k_BxLxHxDh,
                                         self.freqs_cis)
@@ -325,6 +333,7 @@ def build_flax_module(self):
         dtype=utils.dtype_from_str(self.hps['computation_dtype']),
         mlp_activation=self.hps['mlp_activation'],
         normalization=self.hps['normalization'],
+        qk_norm=self.hps['qk_norm'],
     )
     return TransformerDo(config)
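
Note on the change (not part of the patch): the diff enables QK-norm, i.e. a bias-free LayerNorm applied to the per-head query and key projections before the rotary embeddings. The sketch below is a minimal, standalone Flax illustration of that idea under assumed shapes; the module name QKNorm and the sizes B/L/H/Dh are hypothetical and not taken from rope_nanodo.py.

# Minimal sketch (illustration only, not the patched module): bias-free
# LayerNorm over the head dimension of Q and K tensors of shape [B, L, H, Dh],
# mirroring the layer_norm_q / layer_norm_k modules added in the diff.
from typing import Any

import jax
import jax.numpy as jnp
from flax import linen as nn


class QKNorm(nn.Module):
  """Normalizes queries and keys over the head dim before RoPE/attention."""
  dtype: Any = jnp.float32

  @nn.compact
  def __call__(self, q_BxLxHxDh: jax.Array, k_BxLxHxDh: jax.Array):
    # Separate LayerNorms for Q and K, each without a bias term, as in the
    # patch's nn.LayerNorm(dtype=cfg.dtype, use_bias=False) calls.
    q_BxLxHxDh = nn.LayerNorm(dtype=self.dtype, use_bias=False)(q_BxLxHxDh)
    k_BxLxHxDh = nn.LayerNorm(dtype=self.dtype, use_bias=False)(k_BxLxHxDh)
    return q_BxLxHxDh, k_BxLxHxDh


if __name__ == '__main__':
  B, L, H, Dh = 2, 16, 8, 64  # illustrative sizes, not from the patch
  key = jax.random.PRNGKey(0)
  q = jax.random.normal(key, (B, L, H, Dh))
  k = jax.random.normal(key, (B, L, H, Dh))
  mod = QKNorm()
  variables = mod.init(key, q, k)
  q_n, k_n = mod.apply(variables, q, k)
  print(q_n.shape, k_n.shape)  # (2, 16, 8, 64) (2, 16, 8, 64)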