feat: Add support for Flux.2-dev quantization and export #707
What does this PR do?
Type of change: New feature
Overview:
This PR adds support for Flux.2-dev (`black-forest-labs/FLUX.2-dev`) quantization and ONNX export. It enables users to quantize Flux.2 models (INT8/FP8/FP4) and export them to TensorRT by handling the architectural differences in the new model version, specifically the updated RoPE embedding dimensions and attention mechanisms.

Key changes include:
- Added `Flux2Attention` and `Flux2ParallelSelfAttention` in `modelopt/torch/quantization/plugins/diffusers.py` to enable quantization for the new architecture.
- Added `_gen_dummy_inp_and_dyn_shapes_flux2` in `export.py` to correctly handle Flux.2 input shapes, specifically accounting for the RoPE dimension change (from 3 to 4) and updated ID tensor shapes.
- Added `ModelType.FLUX_2_DEV` and mapped it to the correct `Flux2Pipeline` in `quantize.py` and `diffusion_trt.py`.
- Updated `examples/diffusers/README.md` with support status and a specific `trtexec` command for Flux.2.

Usage
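To make the RoPE change concrete, here is a minimal sketch of the shape difference the new dummy-input helper has to account for. The function name and sequence lengths below are illustrative placeholders; only the last-dimension change from 3 to 4 comes from this PR.

```python
def flux2_dummy_id_shapes(img_seq_len=1024, txt_seq_len=512):
    """Illustrative shapes for Flux.2 positional ID tensors.

    Flux.2 RoPE uses 4 positional axes, so `img_ids` / `txt_ids`
    end in 4 (Flux.1 used 3). Sequence lengths are placeholders,
    not values taken from the actual export code.
    """
    rope_dim = 4  # was 3 in Flux.1
    return {
        "img_ids": (img_seq_len, rope_dim),
        "txt_ids": (txt_seq_len, rope_dim),
    }

shapes = flux2_dummy_id_shapes()
print(shapes["img_ids"])  # (1024, 4)
```

An ONNX export driven by dummy inputs with the old 3-wide ID tensors would bake in the wrong static shape, which is why the helper must be version-aware.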
You can now quantize and export Flux.2-dev using the standard quantization script:

```bash
# Example for FP8 quantization
python examples/diffusers/quantization/quantize.py \
    --model flux-2-dev \
    --format fp8 \
    --batch-size 1 \
    --calib-size 128 \
    --quantized-torch-ckpt-save-path ./flux2_fp8.pt \
    --onnx-dir ./onnx_flux2
```

Testing
Tested manually by:
- Running quantization on the `black-forest-labs/FLUX.2-dev` model.
- Verifying that the exported ONNX model uses the updated input shapes (`img_ids` and `txt_ids` with the last dimension as 4).
- Building a TensorRT engine from the exported ONNX with `trtexec`.

Before your PR is "Ready for review"
- Updated the necessary documentation (README.md)

Additional Information
This PR adds support for the Flux.2 architecture, which introduces `Flux2ParallelSelfAttention` and changes the positional embedding (RoPE) dimension size to 4, requiring handling distinct from Flux.1.
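As a rough illustration of why a new attention class needs explicit plugin support: quantization plugins typically map an original module class to a quantized wrapper so a conversion pass can swap modules in place, and a class the registry has never seen is silently left unquantized. All names in this sketch are invented for illustration and are not the actual modelopt API.

```python
# Hypothetical registry mapping original module class names to
# quantized replacements; modelopt's real mechanism differs in detail.
QUANT_MODULE_REGISTRY: dict[str, type] = {}


def register_quant_module(original_cls_name: str):
    """Decorator registering a quantized wrapper for a module class."""
    def decorator(quant_cls: type) -> type:
        QUANT_MODULE_REGISTRY[original_cls_name] = quant_cls
        return quant_cls
    return decorator


@register_quant_module("Flux2ParallelSelfAttention")
class QuantFlux2ParallelSelfAttention:
    """Placeholder: a real wrapper would attach input/weight quantizers."""


print("Flux2ParallelSelfAttention" in QUANT_MODULE_REGISTRY)  # True
```

Under this pattern, registering `Flux2ParallelSelfAttention` (and `Flux2Attention`) is what lets the existing conversion machinery pick up the Flux.2 modules without further changes.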