From 7294c6a79e907522707d6d69e5989b1ff1eec269 Mon Sep 17 00:00:00 2001
From: Oguz Vuruskaner
Date: Wed, 17 Dec 2025 19:16:49 +0300
Subject: [PATCH 01/11] Add Flux2Attention to quantization plugin

Signed-off-by: Oguz Vuruskaner
---
 modelopt/torch/quantization/plugins/diffusers.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/modelopt/torch/quantization/plugins/diffusers.py b/modelopt/torch/quantization/plugins/diffusers.py
index 440d190d3..e40a492fe 100644
--- a/modelopt/torch/quantization/plugins/diffusers.py
+++ b/modelopt/torch/quantization/plugins/diffusers.py
@@ -31,6 +31,8 @@
     from diffusers.models.attention import AttentionModuleMixin
     from diffusers.models.attention_dispatch import AttentionBackendName, attention_backend
     from diffusers.models.transformers.transformer_flux import FluxAttention
+    from diffusers.models.transformers.transformer_flux2 import Flux2Attention
+
     from diffusers.models.transformers.transformer_ltx import LTXAttention
     from diffusers.models.transformers.transformer_wan import WanAttention
 else:
@@ -190,6 +192,8 @@ def forward(self, *args, **kwargs):
     QuantModuleRegistry.register({FluxAttention: "FluxAttention"})(_QuantAttentionModuleMixin)
     QuantModuleRegistry.register({WanAttention: "WanAttention"})(_QuantAttentionModuleMixin)
     QuantModuleRegistry.register({LTXAttention: "LTXAttention"})(_QuantAttentionModuleMixin)
+    QuantModuleRegistry.register({Flux2Attention: "Flux2Attention"})(_QuantAttentionModuleMixin)
+
 
 
 original_scaled_dot_product_attention = F.scaled_dot_product_attention

From ab959dca16788044a85216db56ef1bddb57f78e4 Mon Sep 17 00:00:00 2001
From: Oguz Vuruskaner
Date: Wed, 17 Dec 2025 19:36:41 +0300
Subject: [PATCH 02/11] Update diffusers.py

Signed-off-by: Oguz Vuruskaner
---
 modelopt/torch/quantization/plugins/diffusers.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/modelopt/torch/quantization/plugins/diffusers.py b/modelopt/torch/quantization/plugins/diffusers.py
index e40a492fe..b96be91fa 100644
--- a/modelopt/torch/quantization/plugins/diffusers.py
+++ b/modelopt/torch/quantization/plugins/diffusers.py
@@ -32,7 +32,6 @@
     from diffusers.models.attention_dispatch import AttentionBackendName, attention_backend
     from diffusers.models.transformers.transformer_flux import FluxAttention
     from diffusers.models.transformers.transformer_flux2 import Flux2Attention
-
     from diffusers.models.transformers.transformer_ltx import LTXAttention
     from diffusers.models.transformers.transformer_wan import WanAttention
 else:

From d99c665de7778d085d718546d57d2fd9b77afc37 Mon Sep 17 00:00:00 2001
From: Oguz Vuruskaner
Date: Wed, 17 Dec 2025 20:28:03 +0300
Subject: [PATCH 03/11] Add Flux2ParallelSelfAttention to quantization registry

Signed-off-by: Oguz Vuruskaner
---
 modelopt/torch/quantization/plugins/diffusers.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/modelopt/torch/quantization/plugins/diffusers.py b/modelopt/torch/quantization/plugins/diffusers.py
index b96be91fa..6472c52dd 100644
--- a/modelopt/torch/quantization/plugins/diffusers.py
+++ b/modelopt/torch/quantization/plugins/diffusers.py
@@ -31,7 +31,7 @@
     from diffusers.models.attention import AttentionModuleMixin
     from diffusers.models.attention_dispatch import AttentionBackendName, attention_backend
     from diffusers.models.transformers.transformer_flux import FluxAttention
-    from diffusers.models.transformers.transformer_flux2 import Flux2Attention
+    from diffusers.models.transformers.transformer_flux2 import Flux2Attention, Flux2ParallelSelfAttention
     from diffusers.models.transformers.transformer_ltx import LTXAttention
     from diffusers.models.transformers.transformer_wan import WanAttention
 else:
@@ -192,6 +192,8 @@ def forward(self, *args, **kwargs):
     QuantModuleRegistry.register({WanAttention: "WanAttention"})(_QuantAttentionModuleMixin)
     QuantModuleRegistry.register({LTXAttention: "LTXAttention"})(_QuantAttentionModuleMixin)
     QuantModuleRegistry.register({Flux2Attention: "Flux2Attention"})(_QuantAttentionModuleMixin)
+    QuantModuleRegistry.register({Flux2ParallelSelfAttention: "Flux2ParallelSelfAttention"})(_QuantAttentionModuleMixin)
+
 
 
 original_scaled_dot_product_attention = F.scaled_dot_product_attention
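The three commits above only register the new attention classes; the conversion itself happens when `mtq.quantize` walks the model. A minimal sketch of how the registrations are exercised — the checkpoint id comes from later commits in this series, `FP8_DEFAULT_CFG` is one of modelopt's stock configs, and the call mirrors the `mtq.quantize(...)` snippet in the README rather than the example script's exact flow:

```python
import torch

import modelopt.torch.quantization as mtq
from diffusers import Flux2Pipeline

pipe = Flux2Pipeline.from_pretrained(
    "black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16
).to("cuda")


def forward_pass(model):
    # Calibration: every denoising step invokes the transformer internally,
    # so a few prompts are enough to populate the quantizer statistics.
    pipe("a photo of a corgi on a skateboard", num_inference_steps=20)


mtq.quantize(model=pipe.transformer, config=mtq.FP8_DEFAULT_CFG, forward_func=forward_pass)
# After conversion, Flux2Attention and Flux2ParallelSelfAttention instances
# are backed by the registered _QuantAttentionModuleMixin wrapper.
```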
From ff1e8361436135bf8c75a1c9321d3034008dbee7 Mon Sep 17 00:00:00 2001
From: Oguz Vuruskaner
Date: Wed, 17 Dec 2025 21:55:49 +0300
Subject: [PATCH 04/11] Update quantize.py

Signed-off-by: Oguz Vuruskaner
---
 examples/diffusers/quantization/quantize.py | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py
index df2de4fae..21aef7c40 100644
--- a/examples/diffusers/quantization/quantize.py
+++ b/examples/diffusers/quantization/quantize.py
@@ -48,6 +48,7 @@
 from diffusers import (
     DiffusionPipeline,
     FluxPipeline,
+    Flux2Pipeline,
     LTXConditionPipeline,
     LTXLatentUpsamplePipeline,
     StableDiffusion3Pipeline,
@@ -77,6 +78,7 @@ class ModelType(str, Enum):
     SD35_MEDIUM = "sd3.5-medium"
     FLUX_DEV = "flux-dev"
     FLUX_SCHNELL = "flux-schnell"
+    FLUX_DEV_2 = "flux-2-dev"
     LTX_VIDEO_DEV = "ltx-video-dev"
     WAN22_T2V = "wan2.2-t2v-14b"
 
@@ -138,6 +140,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
     """
     filter_func_map = {
         ModelType.FLUX_DEV: filter_func_default,
+        ModelType.FLUX_DEV_2: filter_func_default,
         ModelType.FLUX_SCHNELL: filter_func_default,
         ModelType.SDXL_BASE: filter_func_default,
         ModelType.SDXL_TURBO: filter_func_default,
@@ -157,6 +160,7 @@
     ModelType.SD3_MEDIUM: "stabilityai/stable-diffusion-3-medium-diffusers",
     ModelType.SD35_MEDIUM: "stabilityai/stable-diffusion-3.5-medium",
     ModelType.FLUX_DEV: "black-forest-labs/FLUX.1-dev",
+    ModelType.FLUX_DEV_2: "black-forest-labs/FLUX.2-dev",
     ModelType.FLUX_SCHNELL: "black-forest-labs/FLUX.1-schnell",
     ModelType.LTX_VIDEO_DEV: "Lightricks/LTX-Video-0.9.7-dev",
     ModelType.WAN22_T2V: "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
@@ -168,6 +172,7 @@
     ModelType.SD3_MEDIUM: StableDiffusion3Pipeline,
     ModelType.SD35_MEDIUM: StableDiffusion3Pipeline,
     ModelType.FLUX_DEV: FluxPipeline,
+    ModelType.FLUX_DEV_2: Flux2Pipeline,
     ModelType.FLUX_SCHNELL: FluxPipeline,
     ModelType.LTX_VIDEO_DEV: LTXConditionPipeline,
     ModelType.WAN22_T2V: WanPipeline,
@@ -221,6 +226,20 @@
             "max_sequence_length": 512,
         },
     },
+    ModelType.FLUX_DEV_2: {
+        "backbone": "transformer",
+        "dataset": {
+            "name": "nateraw/parti-prompts",
+            "split": "train",
+            "column": "Prompt",
+        },
+        "inference_extra_args": {
+            "height": 1024,
+            "width": 1024,
+            "guidance_scale": 3.5,
+            "max_sequence_length": 512,
+        },
+    },
     ModelType.FLUX_SCHNELL: {
         "backbone": "transformer",
         "dataset": {
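The new `FLUX_DEV_2` calibration entry is consumed by the script's generic loop: prompts are drawn from the configured dataset column and the `inference_extra_args` are forwarded to the pipeline on every calibration call. A rough sketch of that flow (not the script's exact control logic; whether `Flux2Pipeline` accepts all of these kwargs is assumed from the entry above):

```python
import torch
from datasets import load_dataset
from diffusers import Flux2Pipeline

pipe = Flux2Pipeline.from_pretrained(
    "black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16
).to("cuda")

# "name", "split", and "column" come straight from the FLUX_DEV_2 entry above.
prompts = load_dataset("nateraw/parti-prompts", split="train")["Prompt"]

for prompt in prompts[:128]:  # --calib-size in the README recipes
    pipe(
        prompt,
        height=1024,
        width=1024,
        guidance_scale=3.5,
        max_sequence_length=512,
        num_inference_steps=20,  # --n-steps
    )
```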
From 4977fd5f7b5355e682c29193fb99664699e4cb8e Mon Sep 17 00:00:00 2001
From: Oguz Vuruskaner
Date: Wed, 17 Dec 2025 21:57:53 +0300
Subject: [PATCH 05/11] Add 'flux-2-dev' to model and dtype mappings

Signed-off-by: Oguz Vuruskaner
---
 examples/diffusers/quantization/diffusion_trt.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/examples/diffusers/quantization/diffusion_trt.py b/examples/diffusers/quantization/diffusion_trt.py
index 2ae32ea8e..f7166cd9e 100644
--- a/examples/diffusers/quantization/diffusion_trt.py
+++ b/examples/diffusers/quantization/diffusion_trt.py
@@ -52,6 +52,7 @@
     "sdxl-turbo": ModelType.SDXL_TURBO,
     "sd3-medium": ModelType.SD3_MEDIUM,
     "flux-dev": ModelType.FLUX_DEV,
+    "flux-2-dev": ModelType.FLUX_DEV,
     "flux-schnell": ModelType.FLUX_SCHNELL,
 }
 
@@ -60,6 +61,7 @@
     "sdxl-turbo": torch.float16,
     "sd3-medium": torch.float16,
     "flux-dev": torch.bfloat16,
+    "flux-2-dev": torch.bfloat16,
     "flux-schnell": torch.bfloat16,
 }
 
@@ -142,7 +144,7 @@ def main():
         "--model",
         type=str,
         default="flux-dev",
-        choices=["sdxl-1.0", "sdxl-turbo", "sd3-medium", "flux-dev", "flux-schnell"],
+        choices=["sdxl-1.0", "sdxl-turbo", "sd3-medium", "flux-dev", "flux-schnell", "flux-2-dev"],
     )
     parser.add_argument(
         "--override-model-path",

From 2850a5e031a5c1b141ae6c841b98d91f38a641f9 Mon Sep 17 00:00:00 2001
From: Oguz Vuruskaner
Date: Wed, 17 Dec 2025 22:00:39 +0300
Subject: [PATCH 06/11] Update export.py

Signed-off-by: Oguz Vuruskaner
---
 .../quantization/onnx_utils/export.py | 92 ++++++++++++++++++-
 1 file changed, 89 insertions(+), 3 deletions(-)

diff --git a/examples/diffusers/quantization/onnx_utils/export.py b/examples/diffusers/quantization/onnx_utils/export.py
index 5a287fab5..5af58442b 100644
--- a/examples/diffusers/quantization/onnx_utils/export.py
+++ b/examples/diffusers/quantization/onnx_utils/export.py
@@ -93,6 +93,14 @@
         "guidance": {0: "batch_size"},
         "latent": {0: "batch_size"},
     },
+    "flux-2-dev": {
+        "hidden_states": {0: "batch_size", 1: "latent_dim"},
+        "encoder_hidden_states": {0: "batch_size"},
+        "timestep": {0: "batch_size"},
+        "img_ids": {0: "latent_dim"},
+        "guidance": {0: "batch_size"},
+        "latent": {0: "batch_size"},
+    },
     "flux-schnell": {
         "hidden_states": {0: "batch_size", 1: "latent_dim"},
         "encoder_hidden_states": {0: "batch_size"},
@@ -265,6 +273,70 @@
     return dummy_kwargs, dynamic_shapes
 
 
+def _gen_dummy_inp_and_dyn_shapes_flux2(backbone, min_bs=1, opt_bs=1, max_bs=1):
+    assert isinstance(backbone, Flux2Transformer2DModel) or isinstance(
+        backbone._orig_mod, Flux2Transformer2DModel
+    )
+    cfg = backbone.config
+
+    text_maxlen = 512  # constant
+
+    min_img_dim = 256  # 256x256
+    opt_img_dim = 4096  # 1024x1024
+    max_img_dim = 16384  # 2048x2048 (Safe upper bound)
+
+    rope_dim = 4
+
+    dynamic_shapes = {
+        "hidden_states": {
+            "min": [min_bs, min_img_dim, cfg.in_channels],
+            "opt": [opt_bs, opt_img_dim, cfg.in_channels],
+            "max": [max_bs, max_img_dim, cfg.in_channels],
+        },
+        "encoder_hidden_states": {
+            "min": [min_bs, text_maxlen, cfg.joint_attention_dim],
+            "opt": [opt_bs, text_maxlen, cfg.joint_attention_dim],
+            "max": [max_bs, text_maxlen, cfg.joint_attention_dim],
+        },
+        "timestep": {
+            "min": [min_bs],
+            "opt": [opt_bs],
+            "max": [max_bs],
+        },
+        "guidance": {
+            "min": [min_bs],
+            "opt": [opt_bs],
+            "max": [max_bs],
+        },
+        "img_ids": {
+            "min": [min_img_dim, rope_dim],
+            "opt": [opt_img_dim, rope_dim],
+            "max": [max_img_dim, rope_dim],
+        },
+        "txt_ids": {
+            "min": [text_maxlen, rope_dim],
+            "opt": [text_maxlen, rope_dim],
+            "max": [text_maxlen, rope_dim],
+        },
+    }
+
+
+    dtype = backbone.dtype
+    device = backbone.device
+
+    dummy_kwargs = {
+        "hidden_states": torch.randn(*dynamic_shapes["hidden_states"]["opt"], dtype=dtype, device=device),
+        "encoder_hidden_states": torch.randn(*dynamic_shapes["encoder_hidden_states"]["opt"], dtype=dtype, device=device),
+        "timestep": torch.ones(*dynamic_shapes["timestep"]["opt"], dtype=dtype, device=device),
+        "img_ids": torch.randn(*dynamic_shapes["img_ids"]["opt"], dtype=dtype, device=device),
+        "txt_ids": torch.randn(*dynamic_shapes["txt_ids"]["opt"], dtype=dtype, device=device),
+        "guidance": torch.full(dynamic_shapes["guidance"]["opt"], 3.5, dtype=dtype, device=device),
+        "return_dict": False,
+    }
+
+    return dummy_kwargs, dynamic_shapes
+
+
 def _gen_dummy_inp_and_dyn_shapes_ltx(backbone, min_bs=2, opt_bs=2):
     assert isinstance(backbone, LTXVideoTransformer3DModel) or isinstance(
         backbone._orig_mod, LTXVideoTransformer3DModel
@@ -356,7 +428,7 @@ def _gen_dummy_inp_and_dyn_shapes_wan(backbone, min_bs=1, opt_bs=2):
 
 
 def update_dynamic_axes(model_id, dynamic_axes):
-    if model_id in ["flux-dev", "flux-schnell"]:
+    if model_id in ["flux-dev", "flux-schnell", "flux-2-dev"]:
         dynamic_axes["out.0"] = dynamic_axes.pop("latent")
     elif model_id in ["sdxl-1.0", "sdxl-turbo"]:
         dynamic_axes["added_cond_kwargs.text_embeds"] = dynamic_axes.pop("text_embeds")
@@ -395,6 +467,10 @@ def generate_dummy_kwargs_and_dynamic_axes_and_shapes(model_id, backbone):
         dummy_kwargs, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_flux(
             backbone, min_bs=1, opt_bs=1
         )
+    elif model_id == "flux-2-dev":
+        dummy_kwargs, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_flux2(
+            backbone, min_bs=1, opt_bs=1
+        )
     elif model_id == "ltx-video-dev":
         dummy_kwargs, dynamic_shapes = _gen_dummy_inp_and_dyn_shapes_ltx(
             backbone, min_bs=2, opt_bs=2
@@ -421,7 +497,7 @@ def get_io_shapes(model_id, onnx_load_path, trt_dynamic_shapes):
         output_name = "sample"
     elif model_id in ["sd3.5-medium"]:
         output_name = "out_hidden_states"
-    elif model_id in ["flux-dev", "flux-schnell"]:
+    elif model_id in ["flux-dev", "flux-schnell", "flux-2-dev"]:
         output_name = "output"
     else:
         raise NotImplementedError(f"Unsupported model_id: {model_id}")
@@ -430,7 +506,7 @@ def get_io_shapes(model_id, onnx_load_path, trt_dynamic_shapes):
         io_shapes = {output_name: trt_dynamic_shapes["minShapes"]["sample"]}
     elif model_id in ["sd3-medium", "sd3.5-medium"]:
         io_shapes = {output_name: trt_dynamic_shapes["minShapes"]["hidden_states"]}
-    elif model_id in ["flux-dev", "flux-schnell"]:
+    elif model_id in ["flux-dev", "flux-schnell", "flux-2-dev"]:
         io_shapes = {}
 
     return io_shapes
@@ -499,6 +575,16 @@ def modelopt_export_sd(backbone, onnx_dir, model_name, precision):
     if model_name == "flux-dev":
         input_names.append("guidance")
         output_names = ["latent"]
+    elif model_name == "flux-2-dev":
+        input_names = [
+            "hidden_states",
+            "encoder_hidden_states",
+            "timestep",
+            "img_ids",
+            "txt_ids",
+            "guidance",
+        ]
+        output_names = ["latent"]
     elif model_name in ["ltx-video-dev"]:
         input_names = [
             "hidden_states",
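Two things are worth noting about `_gen_dummy_inp_and_dyn_shapes_flux2`. First, it references `Flux2Transformer2DModel`, so the hunks above assume that class is imported at the top of `export.py` alongside the other transformer imports (the import is not shown in this diff). Second, the token-count constants follow from FLUX-style latent packing — an 8x VAE downsample followed by 2x2 patchification, so a square image contributes `(side // 16) ** 2` tokens, which is exactly what the `256x256` / `1024x1024` / `2048x2048` comments encode (the packing assumption is inferred from those comments, not stated in the patch):

```python
def flux_token_count(height: int, width: int) -> int:
    # 8x VAE downsample, then 2x2 patch packing -> 16x reduction per side.
    return (height // 16) * (width // 16)


assert flux_token_count(256, 256) == 256       # min_img_dim
assert flux_token_count(1024, 1024) == 4096    # opt_img_dim
assert flux_token_count(2048, 2048) == 16384   # max_img_dim
```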
From a7d8c7563ba20a79e9a213667c04d6f7c09267c1 Mon Sep 17 00:00:00 2001
From: Oguz Vuruskaner
Date: Wed, 17 Dec 2025 22:06:50 +0300
Subject: [PATCH 07/11] Update README.md

Signed-off-by: Oguz Vuruskaner
---
 examples/diffusers/README.md | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/examples/diffusers/README.md b/examples/diffusers/README.md
index db4e24e07..7b085a884 100644
--- a/examples/diffusers/README.md
+++ b/examples/diffusers/README.md
@@ -71,6 +71,7 @@ mtq.quantize(model=transformer, config=quant_config, forward_func=forward_pass)
 
 | Model | fp8 | int8_sq | int4_awq | w4a8_awq<sup>1</sup> | nvfp4<sup>2</sup> | nvfp4_svdquant<sup>3</sup> | Cache Diffusion |
 | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
 | [FLUX](https://huggingface.co/black-forest-labs/FLUX.1-dev) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | - |
+| [FLUX 2](https://huggingface.co/black-forest-labs/FLUX.2-dev) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | - |
 | [Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
 | [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
 | [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | - |
@@ -103,24 +104,24 @@
 
 If you prefer to customize parameters in calibration or run other models, please follow the instructions below.
 
-#### FLUX-Dev|SD3-Medium|SDXL|SDXL-Turbo INT8 [Script](./quantization/quantize.py)
+#### FLUX-Dev|FLUX-2-Dev|SD3-Medium|SDXL|SDXL-Turbo INT8 [Script](./quantization/quantize.py)
 
 ```sh
 python quantize.py \
-    --model {flux-dev|sdxl-1.0|sdxl-turbo|sd3-medium} \
+    --model {flux-dev|flux-2-dev|sdxl-1.0|sdxl-turbo|sd3-medium} \
     --format int8 --batch-size 2 \
     --calib-size 32 --alpha 0.8 --n-steps 20 \
     --model-dtype {Half/BFloat16} --trt-high-precision-dtype {Half|BFloat16} \
     --quantized-torch-ckpt-save-path ./{MODEL_NAME}.pt --onnx-dir {ONNX_DIR}
 ```
 
-#### FLUX-Dev|SDXL|SDXL-Turbo|LTX-Video FP8/FP4 [Script](./quantization/quantize.py)
+#### FLUX-Dev|FLUX-2-Dev|SDXL|SDXL-Turbo|LTX-Video FP8/FP4 [Script](./quantization/quantize.py)
 
 *In our example code, FP4 is only supported for Flux. However, you can modify our script to enable FP4 format support for your own model.*
 
 ```sh
 python quantize.py \
-    --model {flux-dev|sdxl-1.0|sdxl-turbo|ltx-video-dev} --model-dtype {Half|BFloat16} --trt-high-precision-dtype {Half|BFloat16} \
+    --model {flux-dev|flux-2-dev|sdxl-1.0|sdxl-turbo|ltx-video-dev} --model-dtype {Half|BFloat16} --trt-high-precision-dtype {Half|BFloat16} \
     --format {fp8|fp4} --batch-size 2 --calib-size {128|256} --quantize-mha \
     --n-steps 20 --quantized-torch-ckpt-save-path ./{MODEL_NAME}.pt --collect-method default \
     --onnx-dir {ONNX_DIR}
 ```
@@ -252,6 +253,14 @@ trtexec --onnx=./model.onnx --fp8 --bf16 --stronglyTyped \
     --optShapes=hidden_states:1x4096x64,img_ids:4096x3,encoder_hidden_states:1x512x4096,txt_ids:512x3,timestep:1,pooled_projections:1x768,guidance:1 \
     --maxShapes=hidden_states:1x4096x64,img_ids:4096x3,encoder_hidden_states:1x512x4096,txt_ids:512x3,timestep:1,pooled_projections:1x768,guidance:1 \
     --saveEngine=model.plan
+
+# For FLUX-2-Dev FP8
+trtexec --onnx=./model.onnx --fp8 --bf16 --stronglyTyped \
+    --minShapes=hidden_states:1x4096x128,img_ids:4096x4,encoder_hidden_states:1x512x15360,txt_ids:512x4,timestep:1,guidance:1 \
+    --optShapes=hidden_states:1x4096x128,img_ids:4096x4,encoder_hidden_states:1x512x15360,txt_ids:512x4,timestep:1,guidance:1 \
+    --maxShapes=hidden_states:1x4096x128,img_ids:4096x4,encoder_hidden_states:1x512x15360,txt_ids:512x4,timestep:1,guidance:1 \
+    --saveEngine=model.plan
+
 ```
 
 **Please note that `maxShapes` represents the maximum shape of the given tensor. If you want to use a larger batch size or any other dimensions, feel free to adjust the value accordingly.**
@@ -293,7 +302,7 @@ Generate a quantized torch checkpoint using the [Script](./quantization/quantize.py):
 
 ```bash
 python quantize.py \
-    --model {sdxl-1.0|sdxl-turbo|sd3-medium|flux-dev} \
+    --model {sdxl-1.0|sdxl-turbo|sd3-medium|flux-dev|flux-2-dev} \
     --format fp8 \
     --batch-size {1|2} \
     --calib-size 128 \
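Once that checkpoint exists, it can be reloaded before ONNX export or deployment. A sketch using modelopt's standard restore entry point, `mto.restore`; the checkpoint path assumes the `--quantized-torch-ckpt-save-path` from the command above:

```python
import torch

import modelopt.torch.opt as mto
from diffusers import Flux2Pipeline

pipe = Flux2Pipeline.from_pretrained(
    "black-forest-labs/FLUX.2-dev", torch_dtype=torch.bfloat16
)
# Re-applies the quantization state saved by quantize.py to the transformer.
mto.restore(pipe.transformer, "./flux-2-dev.pt")
```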
From 5fcab8bb3503e4dd789921f7b5084aa6ca9cdb08 Mon Sep 17 00:00:00 2001
From: Oguz Vuruskaner
Date: Wed, 17 Dec 2025 22:08:57 +0300
Subject: [PATCH 08/11] Update flux-2-dev model type in diffusion_trt.py

Signed-off-by: Oguz Vuruskaner
---
 examples/diffusers/quantization/diffusion_trt.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/diffusers/quantization/diffusion_trt.py b/examples/diffusers/quantization/diffusion_trt.py
index f7166cd9e..9f6fd229e 100644
--- a/examples/diffusers/quantization/diffusion_trt.py
+++ b/examples/diffusers/quantization/diffusion_trt.py
@@ -52,7 +52,7 @@
     "sdxl-turbo": ModelType.SDXL_TURBO,
     "sd3-medium": ModelType.SD3_MEDIUM,
     "flux-dev": ModelType.FLUX_DEV,
-    "flux-2-dev": ModelType.FLUX_DEV,
+    "flux-2-dev": ModelType.FLUX_2_DEV,
     "flux-schnell": ModelType.FLUX_SCHNELL,
 }

From 50fcff84c4a9ba99165097fa88a89eaffa46c9d6 Mon Sep 17 00:00:00 2001
From: Oguz Vuruskaner
Date: Wed, 17 Dec 2025 22:09:35 +0300
Subject: [PATCH 09/11] Rename FLUX_DEV_2 to FLUX_2_DEV in quantize.py

Signed-off-by: Oguz Vuruskaner
---
 examples/diffusers/quantization/quantize.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/examples/diffusers/quantization/quantize.py b/examples/diffusers/quantization/quantize.py
index 21aef7c40..8bb730f04 100644
--- a/examples/diffusers/quantization/quantize.py
+++ b/examples/diffusers/quantization/quantize.py
@@ -78,7 +78,7 @@ class ModelType(str, Enum):
     SD35_MEDIUM = "sd3.5-medium"
     FLUX_DEV = "flux-dev"
     FLUX_SCHNELL = "flux-schnell"
-    FLUX_DEV_2 = "flux-2-dev"
+    FLUX_2_DEV = "flux-2-dev"
     LTX_VIDEO_DEV = "ltx-video-dev"
     WAN22_T2V = "wan2.2-t2v-14b"
 
@@ -140,7 +140,7 @@ def get_model_filter_func(model_type: ModelType) -> Callable[[str], bool]:
     """
     filter_func_map = {
         ModelType.FLUX_DEV: filter_func_default,
-        ModelType.FLUX_DEV_2: filter_func_default,
+        ModelType.FLUX_2_DEV: filter_func_default,
         ModelType.FLUX_SCHNELL: filter_func_default,
         ModelType.SDXL_BASE: filter_func_default,
         ModelType.SDXL_TURBO: filter_func_default,
@@ -160,7 +160,7 @@
     ModelType.SD3_MEDIUM: "stabilityai/stable-diffusion-3-medium-diffusers",
     ModelType.SD35_MEDIUM: "stabilityai/stable-diffusion-3.5-medium",
     ModelType.FLUX_DEV: "black-forest-labs/FLUX.1-dev",
-    ModelType.FLUX_DEV_2: "black-forest-labs/FLUX.2-dev",
+    ModelType.FLUX_2_DEV: "black-forest-labs/FLUX.2-dev",
     ModelType.FLUX_SCHNELL: "black-forest-labs/FLUX.1-schnell",
     ModelType.LTX_VIDEO_DEV: "Lightricks/LTX-Video-0.9.7-dev",
     ModelType.WAN22_T2V: "Wan-AI/Wan2.2-T2V-A14B-Diffusers",
@@ -172,7 +172,7 @@
     ModelType.SD3_MEDIUM: StableDiffusion3Pipeline,
     ModelType.SD35_MEDIUM: StableDiffusion3Pipeline,
     ModelType.FLUX_DEV: FluxPipeline,
-    ModelType.FLUX_DEV_2: Flux2Pipeline,
+    ModelType.FLUX_2_DEV: Flux2Pipeline,
     ModelType.FLUX_SCHNELL: FluxPipeline,
     ModelType.LTX_VIDEO_DEV: LTXConditionPipeline,
     ModelType.WAN22_T2V: WanPipeline,
@@ -226,7 +226,7 @@
             "max_sequence_length": 512,
         },
     },
-    ModelType.FLUX_DEV_2: {
+    ModelType.FLUX_2_DEV: {
         "backbone": "transformer",
         "dataset": {
             "name": "nateraw/parti-prompts",

From 77dca36e117483c89c1d19ae262bd0e0956e2b3f Mon Sep 17 00:00:00 2001
From: Oguz Vuruskaner
Date: Wed, 17 Dec 2025 22:14:28 +0300
Subject: [PATCH 10/11] Update CHANGELOG with new quantization features

Add changelog entries for Flux.2-dev quantization and ONNX export
support.
Signed-off-by: Oguz Vuruskaner
---
 CHANGELOG.rst | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 61c198026..815714762 100755
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -6,6 +6,8 @@ NVIDIA Model Optimizer Changelog (Linux)
 
 **New Features**
 
+- Add support for **Flux.2-dev** (``black-forest-labs/FLUX.2-dev``) quantization (INT8, FP8, NVFP4) and ONNX export.
+- Add quantization support for ``Flux2Attention`` and ``Flux2ParallelSelfAttention`` layers in the Diffusers plugin.
 - Add support for Transformer Engine quantization for Megatron Core models.
 - Add support for Qwen3-Next model quantization.
 - Add support for dynamically linked TensorRT plugins in the ONNX quantization workflow.

From 61ac2fd0316078432225e97fdfe874705cb74e34 Mon Sep 17 00:00:00 2001
From: Oguz Vuruskaner
Date: Fri, 19 Dec 2025 13:24:00 +0300
Subject: [PATCH 11/11] Update export.py

Signed-off-by: Oguz Vuruskaner
---
 examples/diffusers/quantization/onnx_utils/export.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/diffusers/quantization/onnx_utils/export.py b/examples/diffusers/quantization/onnx_utils/export.py
index 5af58442b..f150d9618 100644
--- a/examples/diffusers/quantization/onnx_utils/export.py
+++ b/examples/diffusers/quantization/onnx_utils/export.py
@@ -283,7 +283,7 @@ def _gen_dummy_inp_and_dyn_shapes_flux2(backbone, min_bs=1, opt_bs=1, max_bs=1):
 
     min_img_dim = 256  # 256x256
     opt_img_dim = 4096  # 1024x1024
-    max_img_dim = 16384  # 2048x2048 (Safe upper bound)
+    max_img_dim = 16384  # 2048x2048
 
     rope_dim = 4