 import enum
 from dataclasses import InitVar, dataclass
 from functools import partial
-from typing import Tuple
+from typing import Optional, Tuple
 
 import jax
 import jax.numpy as jnp
 from flax import nnx
 from flax.typing import Sharding
 from jax.sharding import PartitionSpec
 from jaxtyping import Float
+from qwix._src.core.ragged_dot import ragged_dot as qwix_ragged_dot
+from qwix._src.providers import ptq
 
 from tpu_commons.models.jax.common.base import create_param
 from tpu_commons.models.jax.common.layers import FlaxUtils
 from tpu_commons.models.jax.common.moe.moe import MoE
+from tpu_commons.models.jax.utils.quantization.quantization_utils import (
+    manually_quantize_qwix_activation, manually_quantize_qwix_weight)
 
 modeling_flax_utils = FlaxUtils()
 
@@ -141,6 +145,8 @@ class SparseMoE(MoE):
     tile_size: tuple[int, int, int] = (128, 64, 128)
     use_megablox: bool = False
     mesh: jax.sharding.Mesh
+    # This should be set if and only if you have quantized your model (via Qwix).
+    quantized_dtype: Optional[jnp.dtype] = None
 
     def __post_init__(self, rngs: nnx.Rngs):
         super().__post_init__(rngs)
@@ -348,7 +354,11 @@ def _gmm(self, inputs, kernel, group_sizes):
             raise NotImplementedError(
                 "MegaBlox kernel call is not implemented.")
         else:
-            output = jax.lax.ragged_dot(
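+            # When quantized_dtype is set, dynamically quantize the activations
+            # to float8_e4m3fn with absmax scaling and route the grouped matmul
+            # through Qwix's ragged_dot; otherwise use jax.lax.ragged_dot.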
+            inputs = manually_quantize_qwix_activation(
+                inputs, "ragged_dot", jnp.float8_e4m3fn, [0], {},
+                "absmax") if self.quantized_dtype else inputs
+            ragged_dot_func = qwix_ragged_dot if self.quantized_dtype else jax.lax.ragged_dot
+            output = ragged_dot_func(
                 lhs=inputs,
                 rhs=kernel,
                 group_sizes=group_sizes,
@@ -572,12 +582,27 @@ def __call__(self, x_TD: Float):
             check_rep=False)(
                 SparseMoE._distributed_sparse_moe_fwd)
 
-        return mapped_moe_fwd(
-            self,
-            x_TD,
-            router_weights_TX,
-            selected_experts_TX,
-            self.kernel_gating_EDF.value,
-            self.kernel_up_proj_EDF.value,
-            self.kernel_down_proj_EFD.value,
-        )
+        kernel_gating_EDF = self.kernel_gating_EDF.value
+        kernel_up_proj_EDF = self.kernel_up_proj_EDF.value
+        kernel_down_proj_EFD = self.kernel_down_proj_EFD.value
+
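+        # If the kernels are not already Qwix-quantized (ptq.WithAux), quantize
+        # them here; then unwrap .array so the distributed forward pass receives
+        # the quantized arrays rather than the WithAux wrappers.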
+        if self.quantized_dtype:
+            if not isinstance(kernel_gating_EDF, ptq.WithAux):
+                kernel_gating_EDF = manually_quantize_qwix_weight(
+                    kernel_gating_EDF, self.quantized_dtype, [0, 2], {},
+                    "absmax")
+            if not isinstance(kernel_up_proj_EDF, ptq.WithAux):
+                kernel_up_proj_EDF = manually_quantize_qwix_weight(
+                    kernel_up_proj_EDF, self.quantized_dtype, [0, 2], {},
+                    "absmax")
+            if not isinstance(kernel_down_proj_EFD, ptq.WithAux):
+                kernel_down_proj_EFD = manually_quantize_qwix_weight(
+                    kernel_down_proj_EFD, self.quantized_dtype, [0, 1], {},
+                    "absmax")
+            kernel_gating_EDF = kernel_gating_EDF.array
+            kernel_up_proj_EDF = kernel_up_proj_EDF.array
+            kernel_down_proj_EFD = kernel_down_proj_EFD.array
+
+        return mapped_moe_fwd(self, x_TD, router_weights_TX,
+                              selected_experts_TX, kernel_gating_EDF,
+                              kernel_up_proj_EDF, kernel_down_proj_EFD)
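
For context, here is a minimal, self-contained sketch of the grouped ("ragged") matmul contract of `jax.lax.ragged_dot`, which `_gmm` calls on the unquantized path and which Qwix's `ragged_dot` replaces on the quantized path. The shapes and values below are purely illustrative and are not taken from this change.

```python
import jax
import jax.numpy as jnp

# Grouped ("ragged") matmul: rows of lhs are split into contiguous groups and
# each group is multiplied by its own expert kernel from rhs.
lhs = jnp.arange(6 * 4, dtype=jnp.float32).reshape(6, 4)  # [tokens, d_model]
rhs = jnp.ones((3, 4, 8), dtype=jnp.float32)              # [experts, d_model, d_ff]
group_sizes = jnp.array([2, 1, 3], dtype=jnp.int32)       # tokens per expert, sums to 6

out = jax.lax.ragged_dot(lhs=lhs, rhs=rhs, group_sizes=group_sizes)
print(out.shape)  # (6, 8): each token row is contracted with its expert's kernel
```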