Skip to content

Commit e88f387

Browse files
bzgooglebzgoogle
andauthored
[DeepSeek] Optimize RoPE Cache to remove re-layout (#1073)
Signed-off-by: bzgoogle <beinuoz_google_com@t1v-n-fa0da4f0-w-0.us-central1-c.c.cloud-tpu-inference-test.internal> Co-authored-by: bzgoogle <beinuoz_google_com@t1v-n-fa0da4f0-w-0.us-central1-c.c.cloud-tpu-inference-test.internal>
1 parent 8e19d70 commit e88f387

File tree

2 files changed

+18
-5
lines changed

2 files changed

+18
-5
lines changed

tpu_inference/layers/jax/rope.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import jax
66
from flax import nnx
77
from jax import numpy as jnp
8+
from jax.experimental.layout import Layout, with_layout_constraint
9+
from jax.sharding import NamedSharding, PartitionSpec
810

911

1012
@dataclass(kw_only=True)
@@ -72,7 +74,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
7274
mscale_value: float = 1
7375
mscale_all_dim: float = 0
7476

75-
def initialize_cache(self):
77+
def initialize_cache(self, mesh: jax.sharding.Mesh):
7678
"""Computes and caches the sin/cos embeddings."""
7779
# The second condition is for the Qwix case, where we need to call `initialize_cache` on
7880
# the abstract model. Thus, when we go to call `initialize_cache` on the concrete model,
@@ -81,9 +83,11 @@ def initialize_cache(self):
8183
if self.sin_cos_cache is not None and not isinstance(
8284
self.sin_cos_cache, jax.ShapeDtypeStruct):
8385
return
84-
self.mscale = _yarn_get_mscale(
86+
mscale_val = _yarn_get_mscale(
8587
self.scaling_factor, self.mscale_value) / _yarn_get_mscale(
8688
self.scaling_factor, self.mscale_all_dim)
89+
replicated_sharding = NamedSharding(mesh, PartitionSpec())
90+
self.mscale = jax.device_put(mscale_val, replicated_sharding)
8791
self.sin_cos_cache = self._compute_sin_cos()
8892

8993
def _compute_inv_freq(self):
@@ -103,6 +107,7 @@ def _compute_inv_freq(self):
103107
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
104108
return inv_freq
105109

110+
@jax.jit
106111
def _compute_sin_cos(self):
107112
inv_freq_H = self._compute_inv_freq()
108113
t = jnp.arange(self.original_max_position_embeddings *
@@ -111,12 +116,20 @@ def _compute_sin_cos(self):
111116
freqs = jnp.einsum("...T,k->...Tk", t, inv_freq_H)
112117
sin, cos = jnp.sin(freqs) * self.mscale, jnp.cos(freqs) * self.mscale
113118
cache = jnp.concatenate((cos, sin), axis=-1)
114-
return cache
119+
H = cache.shape[1]
120+
target_dim = ((H - 1) // 128 + 1) * 128
121+
padding_amount = target_dim - self.rotary_dim
122+
pad_width = ((0, 0), (0, padding_amount))
123+
cache_padded = jnp.pad(cache, pad_width, mode='constant')
124+
desired_layout = Layout(major_to_minor=(1, 0))
125+
cache_padded = with_layout_constraint(cache_padded, desired_layout)
126+
return cache_padded
115127

116128
def apply_rope(self, positions: jax.Array, x_TNH: jax.Array):
117129
assert x_TNH.ndim == 3
118130
assert self.sin_cos_cache is not None, "RoPE cache not initialized."
119-
cos_sin_TH = self.sin_cos_cache[positions]
131+
cos_sin_padded = self.sin_cos_cache[positions]
132+
cos_sin_TH = cos_sin_padded[:, :self.rotary_dim]
120133
# cos, sin: (T, H/2)
121134
cos_TH, sin_TH = jnp.split(cos_sin_TH, 2, axis=-1)
122135
assert sin_TH.ndim == 2 and cos_TH.ndim == 2

tpu_inference/models/jax/deepseek_v3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
317317
# we have to pass dynamic arrays here for __call__'s usage.
318318
self.rng = nnx.Rngs(rng)
319319
self.weight_loader.load_weights(self)
320-
self.initialize_cache()
320+
self.initialize_cache(self.mesh)
321321

322322
def initialize_cache(self):
323323
# Initialize RoPE caches after weights are loaded and before JIT compilation.

0 commit comments

Comments
 (0)