Commit aed9b57

jrplatin authored and bzgoogle committed
[JAX][Quantization] Add Qwix support for SparseMatmul (#740)
Signed-off-by: Jacob Platin <jacobplatin@google.com>
1 parent a171d3e commit aed9b57

2 files changed: +36 −12 lines changed


tests/models/jax/common/moe/test_deepseek_moe.py

Lines changed: 0 additions & 1 deletion
@@ -1,4 +1,3 @@
-import os
 import unittest
 
 import jax

tpu_commons/models/jax/common/moe/deepseek_moe.py

Lines changed: 36 additions & 11 deletions
@@ -1,18 +1,22 @@
 import enum
 from dataclasses import InitVar, dataclass
 from functools import partial
-from typing import Tuple
+from typing import Optional, Tuple
 
 import jax
 import jax.numpy as jnp
 from flax import nnx
 from flax.typing import Sharding
 from jax.sharding import PartitionSpec
 from jaxtyping import Float
+from qwix._src.core.ragged_dot import ragged_dot as qwix_ragged_dot
+from qwix._src.providers import ptq
 
 from tpu_commons.models.jax.common.base import create_param
 from tpu_commons.models.jax.common.layers import FlaxUtils
 from tpu_commons.models.jax.common.moe.moe import MoE
+from tpu_commons.models.jax.utils.quantization.quantization_utils import (
+    manually_quantize_qwix_activation, manually_quantize_qwix_weight)
 
 modeling_flax_utils = FlaxUtils()
 
@@ -141,6 +145,8 @@ class SparseMoE(MoE):
     tile_size: tuple[int, int, int] = (128, 64, 128)
     use_megablox: bool = False
     mesh: jax.sharding.Mesh
+    # This should be set if and only if you have quantized your model (via Qwix)
+    quantized_dtype: Optional[jnp.dtype] = None
 
     def __post_init__(self, rngs: nnx.Rngs):
         super().__post_init__(rngs)
@@ -348,7 +354,11 @@ def _gmm(self, inputs, kernel, group_sizes):
            raise NotImplementedError(
                "MegaBlox kernel call is not implemented.")
        else:
-            output = jax.lax.ragged_dot(
+            inputs = manually_quantize_qwix_activation(
+                inputs, "ragged_dot", jnp.float8_e4m3fn, [0], {},
+                "absmax") if self.quantized_dtype else inputs
+            ragged_dot_func = qwix_ragged_dot if self.quantized_dtype else jax.lax.ragged_dot
+            output = ragged_dot_func(
                lhs=inputs,
                rhs=kernel,
                group_sizes=group_sizes,
@@ -572,12 +582,27 @@ def __call__(self, x_TD: Float):
            check_rep=False)(
                SparseMoE._distributed_sparse_moe_fwd)
 
-        return mapped_moe_fwd(
-            self,
-            x_TD,
-            router_weights_TX,
-            selected_experts_TX,
-            self.kernel_gating_EDF.value,
-            self.kernel_up_proj_EDF.value,
-            self.kernel_down_proj_EFD.value,
-        )
+        kernel_gating_EDF = self.kernel_gating_EDF.value
+        kernel_up_proj_EDF = self.kernel_up_proj_EDF.value
+        kernel_down_proj_EFD = self.kernel_down_proj_EFD.value
+
+        if self.quantized_dtype:
+            if not isinstance(kernel_gating_EDF, ptq.WithAux):
+                kernel_gating_EDF = manually_quantize_qwix_weight(
+                    kernel_gating_EDF, self.quantized_dtype, [0, 2], {},
+                    "absmax")
+            if not isinstance(kernel_up_proj_EDF, ptq.WithAux):
+                kernel_up_proj_EDF = manually_quantize_qwix_weight(
+                    kernel_up_proj_EDF, self.quantized_dtype, [0, 2], {},
+                    "absmax")
+            if not isinstance(kernel_down_proj_EFD, ptq.WithAux):
+                kernel_down_proj_EFD = manually_quantize_qwix_weight(
+                    kernel_down_proj_EFD, self.quantized_dtype, [0, 1], {},
+                    "absmax")
+            kernel_gating_EDF = kernel_gating_EDF.array
+            kernel_up_proj_EDF = kernel_up_proj_EDF.array
+            kernel_down_proj_EFD = kernel_down_proj_EFD.array
+
+        return mapped_moe_fwd(self, x_TD, router_weights_TX,
+                              selected_experts_TX, kernel_gating_EDF,
+                              kernel_up_proj_EDF, kernel_down_proj_EFD)
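
For context on the quantized branch added to _gmm above: manually_quantize_qwix_activation(inputs, "ragged_dot", jnp.float8_e4m3fn, [0], {}, "absmax") casts the activations to fp8 with absmax scaling before the grouped matmul. The snippet below is a rough, self-contained sketch of that general pattern only; the helper name absmax_quantize and the reading of [0] as the channelwise (kept) axes are assumptions made here for illustration, not the Qwix or tpu_commons implementation.

# Illustrative sketch of absmax quantization to fp8 (not the Qwix code).
import jax.numpy as jnp


def absmax_quantize(x, qdtype=jnp.float8_e4m3fn, channelwise_axes=(0,)):
    """Quantize x to qdtype with one absmax-derived scale per kept (channelwise) axis."""
    # Reduce the absolute maximum over every axis that is not kept channelwise.
    reduce_axes = tuple(a for a in range(x.ndim) if a not in channelwise_axes)
    absmax = jnp.max(jnp.abs(x), axis=reduce_axes, keepdims=True)
    # Scale so each channel's absmax lands on the largest representable fp8 value.
    scale = absmax / float(jnp.finfo(qdtype).max)
    scale = jnp.where(scale == 0, 1.0, scale)  # guard all-zero channels
    q = (x / scale).astype(qdtype)
    return q, scale  # dequantize later as q.astype(jnp.float32) * scale


# Example: (tokens, hidden) activations quantized with one scale per token (axis 0 kept).
x = jnp.arange(12.0).reshape(3, 4)
q, scale = absmax_quantize(x)
print(q.dtype, scale.shape)  # float8_e4m3fn (3, 1)

The weight path in __call__ follows the same idea with per-expert scales (hence the [0, 2] and [0, 1] axis arguments), presumably so qwix_ragged_dot can run the grouped matmul on low-precision operands while jax.lax.ragged_dot remains the unquantized fallback.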
