55import jax
66from flax import nnx
77from jax import numpy as jnp
8+ from jax .experimental .layout import Layout , with_layout_constraint
9+ from jax .sharding import NamedSharding , PartitionSpec
810
911
1012@dataclass (kw_only = True )
@@ -72,7 +74,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
7274 mscale_value : float = 1
7375 mscale_all_dim : float = 0
7476
75- def initialize_cache (self ):
77+ def initialize_cache (self , mesh : jax . sharding . Mesh ):
7678 """Computes and caches the sin/cos embeddings."""
7779 # The second condition is for the Qwix case, where we need to call `initialize_cache` on
7880 # the abstract model. Thus, when we go to call `initialize_cache` on the concrete model,
@@ -81,9 +83,11 @@ def initialize_cache(self):
8183 if self .sin_cos_cache is not None and not isinstance (
8284 self .sin_cos_cache , jax .ShapeDtypeStruct ):
8385 return
84- self . mscale = _yarn_get_mscale (
86+ mscale_val = _yarn_get_mscale (
8587 self .scaling_factor , self .mscale_value ) / _yarn_get_mscale (
8688 self .scaling_factor , self .mscale_all_dim )
89+ replicated_sharding = NamedSharding (mesh , PartitionSpec ())
90+ self .mscale = jax .device_put (mscale_val , replicated_sharding )
8791 self .sin_cos_cache = self ._compute_sin_cos ()
8892
8993 def _compute_inv_freq (self ):
@@ -103,6 +107,7 @@ def _compute_inv_freq(self):
103107 1 - inv_freq_mask ) + inv_freq_extrapolation * inv_freq_mask
104108 return inv_freq
105109
110+ @jax .jit
106111 def _compute_sin_cos (self ):
107112 inv_freq_H = self ._compute_inv_freq ()
108113 t = jnp .arange (self .original_max_position_embeddings *
@@ -111,12 +116,20 @@ def _compute_sin_cos(self):
111116 freqs = jnp .einsum ("...T,k->...Tk" , t , inv_freq_H )
112117 sin , cos = jnp .sin (freqs ) * self .mscale , jnp .cos (freqs ) * self .mscale
113118 cache = jnp .concatenate ((cos , sin ), axis = - 1 )
114- return cache
119+ H = cache .shape [1 ]
120+ target_dim = ((H - 1 ) // 128 + 1 ) * 128
121+ padding_amount = target_dim - self .rotary_dim
122+ pad_width = ((0 , 0 ), (0 , padding_amount ))
123+ cache_padded = jnp .pad (cache , pad_width , mode = 'constant' )
124+ desired_layout = Layout (major_to_minor = (1 , 0 ))
125+ cache_padded = with_layout_constraint (cache_padded , desired_layout )
126+ return cache_padded
115127
116128 def apply_rope (self , positions : jax .Array , x_TNH : jax .Array ):
117129 assert x_TNH .ndim == 3
118130 assert self .sin_cos_cache is not None , "RoPE cache not initialized."
119- cos_sin_TH = self .sin_cos_cache [positions ]
131+ cos_sin_padded = self .sin_cos_cache [positions ]
132+ cos_sin_TH = cos_sin_padded [:, :self .rotary_dim ]
120133 # cos, sin: (T, H/2)
121134 cos_TH , sin_TH = jnp .split (cos_sin_TH , 2 , axis = - 1 )
122135 assert sin_TH .ndim == 2 and cos_TH .ndim == 2
0 commit comments