2 changes: 2 additions & 0 deletions CHANGELOG.rst
@@ -6,6 +6,8 @@ NVIDIA Model Optimizer Changelog (Linux)

**New Features**

- Add support for FLUX.2-dev (``black-forest-labs/FLUX.2-dev``) quantization (INT8, FP8, NVFP4) and ONNX export.
- Add quantization support for ``Flux2Attention`` and ``Flux2ParallelSelfAttention`` layers in the Diffusers plugin.
- Add support for Transformer Engine quantization for Megatron Core models.
- Add support for Qwen3-Next model quantization.
- Add support for dynamically linked TensorRT plugins in the ONNX quantization workflow.
19 changes: 14 additions & 5 deletions examples/diffusers/README.md
@@ -71,6 +71,7 @@ mtq.quantize(model=transformer, config=quant_config, forward_func=forward_pass)
| Model | fp8 | int8_sq | int4_awq | w4a8_awq<sup>1</sup> | nvfp4<sup>2</sup> | nvfp4_svdquant<sup>3</sup> | Cache Diffusion |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | - |
| [FLUX 2](https://huggingface.co/black-forest-labs/FLUX.2-dev) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | - |
| [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | - |
@@ -103,24 +104,24 @@ bash build_sdxl_8bit_engine.sh --format {FORMAT} # FORMAT can be int8 or fp8

If you prefer to customize parameters in calibration or run other models, please follow the instructions below.

#### FLUX-Dev|FLUX-2-Dev|SD3-Medium|SDXL|SDXL-Turbo INT8 [Script](./quantization/quantize.py)

```sh
python quantize.py \
--model {flux-dev|flux-2-dev|sdxl-1.0|sdxl-turbo|sd3-medium} \
--format int8 --batch-size 2 \
--calib-size 32 --alpha 0.8 --n-steps 20 \
--model-dtype {Half|BFloat16} --trt-high-precision-dtype {Half|BFloat16} \
--quantized-torch-ckpt-save-path ./{MODEL_NAME}.pt --onnx-dir {ONNX_DIR}
```
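Here `--alpha` is the SmoothQuant migration strength: higher values shift more of the quantization difficulty from activations onto the weights. A minimal sketch of the programmatic equivalent, assuming ModelOpt's `INT8_SMOOTHQUANT_CFG` recipe and its dict-style `algorithm` field (check your installed ModelOpt version for the exact layout):

```python
import copy

import modelopt.torch.quantization as mtq

# Hedged sketch: the programmatic equivalent of `--format int8 --alpha 0.8`.
# Copy the stock recipe before editing so the module-level config stays intact.
quant_config = copy.deepcopy(mtq.INT8_SMOOTHQUANT_CFG)
quant_config["algorithm"] = {"method": "smoothquant", "alpha": 0.8}
```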

#### FLUX-Dev|FLUX-2-Dev|SDXL|SDXL-Turbo|LTX-Video FP8/FP4 [Script](./quantization/quantize.py)

*In our example code, FP4 is only supported for the Flux models. However, you can modify our script to enable FP4 for your own model, as sketched below.*
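As a minimal sketch of such a modification (assuming hypothetical `load_my_backbone` and `run_calibration_prompts` helpers in place of your own model-loading and calibration code, and reusing the `mtq.quantize` pattern shown at the top of this README):

```python
# A sketch, not a drop-in recipe: quantize a custom diffusion backbone to NVFP4.
import modelopt.torch.quantization as mtq

backbone = load_my_backbone()  # hypothetical: returns your transformer/UNet

def forward_pass(model):
    # Run a small set of calibration prompts through your pipeline so that
    # activation statistics can be collected for the backbone.
    run_calibration_prompts(model)  # hypothetical helper

# NVFP4_DEFAULT_CFG is ModelOpt's stock NVFP4 recipe; start from it and keep
# quality-sensitive layers in high precision if outputs degrade.
mtq.quantize(model=backbone, config=mtq.NVFP4_DEFAULT_CFG, forward_func=forward_pass)
```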

```sh
python quantize.py \
--model {flux-dev|flux-2-dev|sdxl-1.0|sdxl-turbo|ltx-video-dev} --model-dtype {Half|BFloat16} --trt-high-precision-dtype {Half|BFloat16} \
--format {fp8|fp4} --batch-size 2 --calib-size {128|256} --quantize-mha \
--n-steps 20 --quantized-torch-ckpt-save-path ./{MODEL_NAME}.pt --collect-method default \
--onnx-dir {ONNX_DIR}
```

@@ -252,6 +253,14 @@

```sh
# For FLUX-Dev FP8
trtexec --onnx=./model.onnx --fp8 --bf16 --stronglyTyped \
--optShapes=hidden_states:1x4096x64,img_ids:4096x3,encoder_hidden_states:1x512x4096,txt_ids:512x3,timestep:1,pooled_projections:1x768,guidance:1 \
--maxShapes=hidden_states:1x4096x64,img_ids:4096x3,encoder_hidden_states:1x512x4096,txt_ids:512x3,timestep:1,pooled_projections:1x768,guidance:1 \
--saveEngine=model.plan

# For FLUX-2-Dev FP8
trtexec --onnx=./model.onnx --fp8 --bf16 --stronglyTyped \
--minShapes=hidden_states:1x4096x128,img_ids:4096x4,encoder_hidden_states:1x512x15360,txt_ids:512x4,timestep:1,guidance:1 \
--optShapes=hidden_states:1x4096x128,img_ids:4096x4,encoder_hidden_states:1x512x15360,txt_ids:512x4,timestep:1,guidance:1 \
--maxShapes=hidden_states:1x4096x128,img_ids:4096x4,encoder_hidden_states:1x512x15360,txt_ids:512x4,timestep:1,guidance:1 \
--saveEngine=model.plan

```

**Please note that `maxShapes` represents the maximum shape of each tensor. If you want to use a larger batch size or other dimensions, adjust the values accordingly.** The sketch below shows how the FLUX.2-dev values are derived.
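As a rough guide (an assumption based on the FLUX family's 16x16 spatial compression, which is consistent with the shapes above), a hypothetical helper for computing these values:

```python
# A sketch for deriving trtexec shape values for FLUX.2-dev. Assumes FLUX-style
# 16x16 spatial compression, so an HxW image becomes (H/16)*(W/16) latent
# tokens of 128 channels each.
def flux2_trt_shapes(batch_size: int, height: int, width: int) -> dict[str, tuple]:
    tokens = (height // 16) * (width // 16)
    return {
        "hidden_states": (batch_size, tokens, 128),
        "img_ids": (tokens, 4),  # one 4-dim RoPE id per latent token
        "encoder_hidden_states": (batch_size, 512, 15360),  # 512-token text context
        "txt_ids": (512, 4),
    }

# 1024x1024 at batch size 1 reproduces the 1x4096x128 / 4096x4 values above.
print(flux2_trt_shapes(1, 1024, 1024))
```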
@@ -293,7 +302,7 @@ Generate a quantized torch checkpoint using the [Script](./quantization/quantize.py)

```bash
python quantize.py \
--model {sdxl-1.0|sdxl-turbo|sd3-medium|flux-dev|flux-2-dev} \
--format fp8 \
--batch-size {1|2} \
--calib-size 128 \
    ...
```
4 changes: 3 additions & 1 deletion examples/diffusers/quantization/diffusion_trt.py
@@ -52,6 +52,7 @@
"sdxl-turbo": ModelType.SDXL_TURBO,
"sd3-medium": ModelType.SD3_MEDIUM,
"flux-dev": ModelType.FLUX_DEV,
"flux-2-dev": ModelType.FLUX_2_DEV,
"flux-schnell": ModelType.FLUX_SCHNELL,
}

@@ -60,6 +61,7 @@
"sdxl-turbo": torch.float16,
"sd3-medium": torch.float16,
"flux-dev": torch.bfloat16,
"flux-2-dev": torch.bfloat16,
"flux-schnell": torch.bfloat16,
}

@@ -142,7 +144,7 @@ def main():
"--model",
type=str,
default="flux-dev",
choices=["sdxl-1.0", "sdxl-turbo", "sd3-medium", "flux-dev", "flux-schnell"],
choices=["sdxl-1.0", "sdxl-turbo", "sd3-medium", "flux-dev", "flux-schnell", "flux-2-dev"],
)
parser.add_argument(
"--override-model-path",
92 changes: 89 additions & 3 deletions examples/diffusers/quantization/onnx_utils/export.py
@@ -93,6 +93,14 @@
"guidance": {0: "batch_size"},
"latent": {0: "batch_size"},
},
"flux-2-dev": {
"hidden_states": {0: "batch_size", 1: "latent_dim"},
"encoder_hidden_states": {0: "batch_size"},
"timestep": {0: "batch_size"},
"img_ids": {0: "latent_dim"},
"guidance": {0: "batch_size"},
"latent": {0: "batch_size"}
},
"flux-schnell": {
"hidden_states": {0: "batch_size", 1: "latent_dim"},
"encoder_hidden_states": {0: "batch_size"},
@@ -265,6 +273,70 @@ def _gen_dummy_inp_and_dyn_shapes_flux(backbone, min_bs=1, opt_bs=1):
return dummy_kwargs, dynamic_shapes


def _gen_dummy_inp_and_dyn_shapes_flux2(backbone, min_bs=1, opt_bs=1, max_bs=1):
    assert isinstance(backbone, Flux2Transformer2DModel) or isinstance(
        backbone._orig_mod, Flux2Transformer2DModel
    )
    cfg = backbone.config

    text_maxlen = 512  # fixed 512-token text context (max_sequence_length)

    # Latent token counts for square images ((H/16) * (W/16) tokens):
    min_img_dim = 256  # 256x256
    opt_img_dim = 4096  # 1024x1024
    max_img_dim = 16384  # 2048x2048

    rope_dim = 4

    dynamic_shapes = {
        "hidden_states": {
            "min": [min_bs, min_img_dim, cfg.in_channels],
            "opt": [opt_bs, opt_img_dim, cfg.in_channels],
            "max": [max_bs, max_img_dim, cfg.in_channels],
        },
        "encoder_hidden_states": {
            "min": [min_bs, text_maxlen, cfg.joint_attention_dim],
            "opt": [opt_bs, text_maxlen, cfg.joint_attention_dim],
            "max": [max_bs, text_maxlen, cfg.joint_attention_dim],
        },
        "timestep": {"min": [min_bs], "opt": [opt_bs], "max": [max_bs]},
        "guidance": {"min": [min_bs], "opt": [opt_bs], "max": [max_bs]},
        "img_ids": {
            "min": [min_img_dim, rope_dim],
            "opt": [opt_img_dim, rope_dim],
            "max": [max_img_dim, rope_dim],
        },
        "txt_ids": {
            "min": [text_maxlen, rope_dim],
            "opt": [text_maxlen, rope_dim],
            "max": [text_maxlen, rope_dim],
        },
    }

    dtype = backbone.dtype
    device = backbone.device

    dummy_kwargs = {
        "hidden_states": torch.randn(
            *dynamic_shapes["hidden_states"]["opt"], dtype=dtype, device=device
        ),
        "encoder_hidden_states": torch.randn(
            *dynamic_shapes["encoder_hidden_states"]["opt"], dtype=dtype, device=device
        ),
        "timestep": torch.ones(*dynamic_shapes["timestep"]["opt"], dtype=dtype, device=device),
        "img_ids": torch.randn(*dynamic_shapes["img_ids"]["opt"], dtype=dtype, device=device),
        "txt_ids": torch.randn(*dynamic_shapes["txt_ids"]["opt"], dtype=dtype, device=device),
        "guidance": torch.full(dynamic_shapes["guidance"]["opt"], 3.5, dtype=dtype, device=device),
        "return_dict": False,
    }

    return dummy_kwargs, dynamic_shapes


def _gen_dummy_inp_and_dyn_shapes_ltx(backbone, min_bs=2, opt_bs=2):
assert isinstance(backbone, LTXVideoTransformer3DModel) or isinstance(
backbone._orig_mod, LTXVideoTransformer3DModel
@@ -356,7 +428,7 @@ def _gen_dummy_inp_and_dyn_shapes_wan(backbone, min_bs=1, opt_bs=2):


def update_dynamic_axes(model_id, dynamic_axes):
if model_id in ["flux-dev", "flux-schnell"]:
if model_id in ["flux-dev", "flux-schnell", "flux-2-dev"]:
dynamic_axes["out.0"] = dynamic_axes.pop("latent")
elif model_id in ["sdxl-1.0", "sdxl-turbo"]:
dynamic_axes["added_cond_kwargs.text_embeds"] = dynamic_axes.pop("text_embeds")
@@ -395,6 +467,10 @@ def generate_dummy_kwargs_and_dynamic_axes_and_shapes(model_id, backbone):
dummy_kwargs, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_flux(
backbone, min_bs=1, opt_bs=1
)
elif model_id == "flux-2-dev":
dummy_kwargs, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_flux2(
backbone, min_bs=1, opt_bs=1
)
elif model_id == "ltx-video-dev":
dummy_kwargs, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_ltx(
backbone, min_bs=2, opt_bs=2
@@ -421,7 +497,7 @@ def get_io_shapes(model_id, onnx_load_path, trt_dynamic_shapes):
output_name = "sample"
elif model_id in ["sd3.5-medium"]:
output_name = "out_hidden_states"
elif model_id in ["flux-dev", "flux-schnell"]:
elif model_id in ["flux-dev", "flux-schnell", "flux-2-dev"]:
output_name = "output"
else:
raise NotImplementedError(f"Unsupported model_id: {model_id}")
@@ -430,7 +506,7 @@
io_shapes = {output_name: trt_dynamic_shapes["minShapes"]["sample"]}
elif model_id in ["sd3-medium", "sd3.5-medium"]:
io_shapes = {output_name: trt_dynamic_shapes["minShapes"]["hidden_states"]}
elif model_id in ["flux-dev", "flux-schnell"]:
elif model_id in ["flux-dev", "flux-schnell", "flux-2-dev"]:
io_shapes = {}

return io_shapes
@@ -499,6 +575,16 @@ def modelopt_export_sd(backbone, onnx_dir, model_name, precision):
if model_name == "flux-dev":
input_names.append("guidance")
output_names = ["latent"]
elif model_name == "flux-2-dev":
input_names = [
"hidden_states",
"encoder_hidden_states",
"timestep",
"img_ids",
"txt_ids",
"guidance",
]
output_names = ["latent"]
elif model_name in ["ltx-video-dev"]:
input_names = [
"hidden_states",
19 changes: 19 additions & 0 deletions examples/diffusers/quantization/quantize.py
@@ -48,6 +48,7 @@
from diffusers import (
DiffusionPipeline,
    Flux2Pipeline,
    FluxPipeline,
LTXConditionPipeline,
LTXLatentUpsamplePipeline,
StableDiffusion3Pipeline,
@@ -77,6 +78,7 @@ class ModelType(str, Enum):
SD35_MEDIUM = "sd3.5-medium"
FLUX_DEV = "flux-dev"
FLUX_SCHNELL = "flux-schnell"
FLUX_2_DEV = "flux-2-dev"
LTX_VIDEO_DEV = "ltx-video-dev"
WAN22_T2V = "wan2.2-t2v-14b"

@@ -138,6 +140,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
"""
filter_func_map = {
ModelType.FLUX_DEV: filter_func_default,
ModelType.FLUX_2_DEV: filter_func_default,
ModelType.FLUX_SCHNELL: filter_func_default,
ModelType.SDXL_BASE: filter_func_default,
ModelType.SDXL_TURBO: filter_func_default,
@@ -157,6 +160,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
ModelType.SD3_MEDIUM: "stabilityai/stable-diffusion-3-medium-diffusers",
ModelType.SD35_MEDIUM: "stabilityai/stable-diffusion-3.5-medium",
ModelType.FLUX_DEV: "black-forest-labs/FLUX.1-dev",
ModelType.FLUX_2_DEV: "black-forest-labs/FLUX.2-dev",
ModelType.FLUX_SCHNELL: "black-forest-labs/FLUX.1-schnell",
ModelType.LTX_VIDEO_DEV: "Lightricks/LTX-Video-0.9.7-dev",
ModelType.WAN22_T2V: "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
@@ -168,6 +172,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
ModelType.SD3_MEDIUM: StableDiffusion3Pipeline,
ModelType.SD35_MEDIUM: StableDiffusion3Pipeline,
    ModelType.FLUX_DEV: FluxPipeline,
    ModelType.FLUX_2_DEV: Flux2Pipeline,
ModelType.FLUX_SCHNELL: FluxPipeline,
ModelType.LTX_VIDEO_DEV: LTXConditionPipeline,
ModelType.WAN22_T2V: WanPipeline,
@@ -221,6 +226,20 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
"max_sequence_length": 512,
},
},
ModelType.FLUX_2_DEV: {
"backbone": "transformer",
"dataset": {
"name": "nateraw/parti-prompts",
"split": "train",
"column": "Prompt",
},
"inference_extra_args": {
"height": 1024,
"width": 1024,
"guidance_scale": 3.5,
"max_sequence_length": 512,
},
},
ModelType.FLUX_SCHNELL: {
"backbone": "transformer",
"dataset": {
5 changes: 5 additions & 0 deletions modelopt/torch/quantization/plugins/diffusers.py
@@ -31,6 +31,7 @@
from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_dispatch import AttentionBackendName, attention_backend
from diffusers.models.transformers.transformer_flux import FluxAttention
    from diffusers.models.transformers.transformer_flux2 import (
        Flux2Attention,
        Flux2ParallelSelfAttention,
    )
from diffusers.models.transformers.transformer_ltx import LTXAttention
from diffusers.models.transformers.transformer_wan import WanAttention
else:
@@ -190,6 +191,10 @@ def forward(self, *args, **kwargs):
QuantModuleRegistry.register({FluxAttention: "FluxAttention"})(_QuantAttentionModuleMixin)
QuantModuleRegistry.register({WanAttention: "WanAttention"})(_QuantAttentionModuleMixin)
QuantModuleRegistry.register({LTXAttention: "LTXAttention"})(_QuantAttentionModuleMixin)
# FLUX.2 attention variants reuse the shared quantized attention mixin.
QuantModuleRegistry.register({Flux2Attention: "Flux2Attention"})(_QuantAttentionModuleMixin)
QuantModuleRegistry.register({Flux2ParallelSelfAttention: "Flux2ParallelSelfAttention"})(
    _QuantAttentionModuleMixin
)

original_scaled_dot_product_attention = F.scaled_dot_product_attention