From 9e47b974e73e41e86d63253c9bc849fc7fa35629 Mon Sep 17 00:00:00 2001
From: Asha Anoosheh
Date: Wed, 17 Dec 2025 07:59:50 -0800
Subject: [PATCH 1/6] Streamline KDTrainer for FSDP2

Signed-off-by: Asha Anoosheh
---
 examples/llm_distill/main.py                  |  2 +-
 modelopt/torch/distill/plugins/huggingface.py | 51 +++++++++----------
 2 files changed, 25 insertions(+), 28 deletions(-)

diff --git a/examples/llm_distill/main.py b/examples/llm_distill/main.py
index 64430bf88..3be21f79e 100644
--- a/examples/llm_distill/main.py
+++ b/examples/llm_distill/main.py
@@ -153,7 +153,7 @@ def train():
     # Save checkpoint
     logger.info("Saving checkpoint...")
     trainer.save_state()
-    trainer.save_model(trainer.args.output_dir, export_student=True)
+    trainer.save_model(trainer.args.output_dir)
     logger.info("Checkpoint saved.")

diff --git a/modelopt/torch/distill/plugins/huggingface.py b/modelopt/torch/distill/plugins/huggingface.py
index 465c951fc..c52ebaf08 100644
--- a/modelopt/torch/distill/plugins/huggingface.py
+++ b/modelopt/torch/distill/plugins/huggingface.py
@@ -15,17 +15,21 @@

 """ModelOpt plugin to train HuggingFace models with knowledge distillation."""

-import torch
 from transformers.modeling_outputs import CausalLMOutputWithPast

 import modelopt.torch.distill as mtd
-import modelopt.torch.opt as mto
 from modelopt.torch.opt.plugins import ModelOptHFTrainer


 class KDTrainer(ModelOptHFTrainer):
     """Distillation trainer for HuggingFace models."""

+    def __init__(self, *args, **kwargs):
+        """Initialize the trainer."""
+        super().__init__(*args, **kwargs)
+        if self.is_fsdp_enabled and not self.accelerator.is_fsdp2:
+            raise ValueError("FSDP1 is not supported for distillation. Use FSDP2 instead.")
+
     def compute_loss(self, model, inputs, *args, **kwargs):
         """Compute loss for distillation.

@@ -49,7 +53,6 @@ def save_model(
         self,
         output_dir: str | None = None,
         _internal_call: bool = False,
-        export_student: bool = False,
         *args,
         **kwargs,
     ):
@@ -57,36 +60,30 @@ def save_model(
         """Dumps model and ModelOpt states to disk.

         Args:
             output_dir: The directory to save the model and ModelOpt states.
-            export_student: Whether to export the student model.
-
         """
         if output_dir is None:
             output_dir = self.args.output_dir
+        model = self.accelerator.unwrap_model(self.model)

-        if not _internal_call and self.is_fsdp_enabled:
-            with model.hide_teacher_model(enable=export_student):
-                state_dict = self.accelerator.get_state_dict(self.model)
-                modelopt_state = mto.modelopt_state(model)
-            if export_student:
-                # Need to wait, otherwise FSDP weights may be deleted before rank 0 can gather them
-                self.accelerator.wait_for_everyone()
-                model = model.export()
-
-            if self.accelerator.is_main_process:
-                model.save_pretrained(
-                    output_dir,
-                    is_main_process=self.accelerator.is_main_process,
-                    save_function=self.accelerator.save,
-                    state_dict=state_dict,
-                )
-                self.processing_class.save_pretrained(output_dir)
-                torch.save(modelopt_state, f"{output_dir}/modelopt_state.pth")
-        else:
-            model = model.export() if export_student else model
-            super().save_model(output_dir, _internal_call, *args, **kwargs)
+        with model.hide_teacher_model(), model.hide_loss_modules(enable=not _internal_call):
+            if _internal_call:
+                return super().save_model(output_dir, _internal_call, *args, **kwargs)
+
+            extra_kwargs = {}
+            if self.is_fsdp_enabled:
+                extra_kwargs["save_function"] = self.accelerator.save
+                extra_kwargs["state_dict"] = self.accelerator.get_state_dict(self.model)
+                self.accelerator.wait_for_everyone()  # needed to prevent hang somehow
+
+            model.save_pretrained(
+                output_dir,
+                is_main_process=self.accelerator.is_main_process,
+                **extra_kwargs,
+            )
+            self.processing_class.save_pretrained(output_dir)

     def train(self, *args, **kwargs):
-        """Train the model."""
+        """Train the model using the distillation loss."""
         self.compute_loss_func = lambda *args, **kwargs: self.model.compute_kd_loss()
         return super().train(*args, **kwargs)

From f2d97a10f123b42353fd9bb26cbe17caddfc5150 Mon Sep 17 00:00:00 2001
From: Asha Anoosheh
Date: Wed, 17 Dec 2025 08:00:54 -0800
Subject: [PATCH 2/6] Refactor QADTrainer too

Signed-off-by: Asha Anoosheh
---
 examples/llm_qat/README.md                  |  2 +-
 .../llm_qat/llama_factory/llama_factory.py  |  4 +---
 examples/llm_qat/main.py                    |  3 +--
 .../plugins/transformers_trainer.py         | 20 +++++++-------------
 4 files changed, 10 insertions(+), 19 deletions(-)

diff --git a/examples/llm_qat/README.md b/examples/llm_qat/README.md
index 75c2cf02b..6c4ac7e5f 100644
--- a/examples/llm_qat/README.md
+++ b/examples/llm_qat/README.md
@@ -147,7 +147,7 @@ trainer = QADTrainer(
 trainer.train()  # Train the quantized model using distillation (i.e, QAD)

 # Save the final student model weights; An example usage
-trainer.save_model(export_student=True)
+trainer.save_model()
 ```

 ### NeMo QAT/QAD Simplified Flow Example
diff --git a/examples/llm_qat/llama_factory/llama_factory.py b/examples/llm_qat/llama_factory/llama_factory.py
index 240afe19a..4e0d97ea2 100644
--- a/examples/llm_qat/llama_factory/llama_factory.py
+++ b/examples/llm_qat/llama_factory/llama_factory.py
@@ -249,11 +249,9 @@ def create_modelcard_and_push(
     ) -> None:
         original_fn(trainer, *args, **kwargs)

-        # export the student model for quantization aware distillation
-        kwargs = {"export_student": True} if hasattr(trainer, "distill_config") else {}
         # save the model in the output directory
         trainer.save_state()
-        trainer.save_model(output_dir=trainer.args.output_dir, **kwargs)
+        trainer.save_model(output_dir=trainer.args.output_dir)

     module.create_modelcard_and_push = create_modelcard_and_push

diff --git a/examples/llm_qat/main.py b/examples/llm_qat/main.py
index 30f49a6a5..2fa23d16a 100644
--- a/examples/llm_qat/main.py
+++ b/examples/llm_qat/main.py
@@ -270,8 +270,7 @@ def train():
     if training_args.do_train or quant_args.quant_cfg is not None:
         print_rank_0("Saving the model...")
         trainer.save_state()
-        kwargs = {"export_student": True} if training_args.distill else {}
-        trainer.save_model(training_args.output_dir, **kwargs)
+        trainer.save_model(training_args.output_dir)


 if __name__ == "__main__":
diff --git a/modelopt/torch/quantization/plugins/transformers_trainer.py b/modelopt/torch/quantization/plugins/transformers_trainer.py
index d0ca15a6f..2f8847d1f 100644
--- a/modelopt/torch/quantization/plugins/transformers_trainer.py
+++ b/modelopt/torch/quantization/plugins/transformers_trainer.py
@@ -255,8 +255,7 @@ def train(self, *args, **kwargs):
         """Train the model."""
         outputs = super().train(*args, **kwargs)
         print_rank_0(
-            "Training completed. Please save the final model using `Trainer.save_model()` "
-            "to preserve ModelOpt states."
+            "Training completed. Please save the final model using `Trainer.save_model()` to preserve ModelOpt states."
         )
         return outputs

@@ -271,8 +270,7 @@ def save_model(self, *args, **kwargs):
             original_type = self.accelerator.state.fsdp_plugin.state_dict_type
             self.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
             outputs = super().save_model(*args, **kwargs)
-            if torch.distributed.is_initialized():
-                torch.distributed.barrier()
+            self.accelerator.wait_for_everyone()
             if mto.ModeloptStateManager.is_converted(self.accelerator.unwrap_model(self.model)):
                 print_rank_0(
                     "Model saved. To restore, call mto.enable_huggingface_checkpointing() first before loading the "
@@ -414,7 +412,6 @@ def save_model(
         self,
         output_dir: str | None = None,
         _internal_call: bool = False,
-        export_student: bool = False,
         *args,
         **kwargs,
     ):
@@ -422,15 +419,12 @@ def save_model(
         """Dumps model and ModelOpt states to disk.

         Args:
             output_dir: The directory to save the model and ModelOpt states.
-            export_student: Whether to export the student model.
         """
         if self.accelerator.is_fsdp2 and "SHARDED_STATE_DICT" in str(
             self.accelerator.state.fsdp_plugin.state_dict_type
         ):
-            if export_student:
-                model = self.accelerator.unwrap_model(self.model)
-                model = model.export()
-            return QATTrainer.save_model(self, output_dir, _internal_call, *args, **kwargs)
-        return KDTrainer.save_model(
-            self, output_dir, _internal_call, export_student, *args, **kwargs
-        )
+            model = self.accelerator.unwrap_model(self.model)
+            with model.hide_teacher_model(), model.hide_loss_modules(enable=not _internal_call):
+                return QATTrainer.save_model(self, output_dir, _internal_call, *args, **kwargs)
+        else:
+            return KDTrainer.save_model(self, output_dir, _internal_call, *args, **kwargs)

From 62d524182df8443231fe21f0d33f77c40be6e603 Mon Sep 17 00:00:00 2001
From: Asha Anoosheh
Date: Thu, 18 Dec 2025 06:00:38 -0800
Subject: [PATCH 3/6] No need for teacher_factory

Signed-off-by: Asha Anoosheh
---
 .../bert_prune_distill_quantize.py          |  9 +++----
 examples/llm_qat/README.md                  |  9 +------
 .../llm_qat/llama_factory/llama_factory.py  | 16 ++++---------
 examples/llm_qat/main.py                    | 24 ++++++-------------
 4 files changed, 15 insertions(+), 43 deletions(-)

diff --git a/examples/chained_optimizations/bert_prune_distill_quantize.py b/examples/chained_optimizations/bert_prune_distill_quantize.py
index 6bfe5dd4e..bd22e3074 100644
--- a/examples/chained_optimizations/bert_prune_distill_quantize.py
+++ b/examples/chained_optimizations/bert_prune_distill_quantize.py
@@ -861,11 +861,6 @@ def postprocess_qa_predictions(
     return eval_metric


-# Model Optimizer: Define a teacher factory for initializing the distillation model
-def teacher_factory(model_name_or_path):
-    return AutoModelForQuestionAnswering.from_pretrained(model_name_or_path)
-
-
 # Model Optimizer: Define a custom distillation loss function that uses start and end logits
 class StartEndLogitsDistillationLoss(mtd.LogitsDistillationLoss):
     def forward(self, outputs_s, outputs_t):
@@ -1199,7 +1194,9 @@ def forward_loop(model):
         logger.info(f"Using distillation with teacher {args.model_name_or_path}")

         kd_config = {
-            "teacher_model": (teacher_factory, (args.model_name_or_path,), {}),
+            "teacher_model": AutoModelForQuestionAnswering.from_pretrained(
+                args.model_name_or_path,
+            ),
             "criterion": StartEndLogitsDistillationLoss(args.temperature),
         }
         model = mtd.convert(model, mode=[("kd_loss", kd_config)])
diff --git a/examples/llm_qat/README.md b/examples/llm_qat/README.md
index 6c4ac7e5f..f5dbdaaac 100644
--- a/examples/llm_qat/README.md
+++ b/examples/llm_qat/README.md
@@ -123,14 +123,7 @@ from modelopt.torch.quantization.plugins.transformers_trainer import QADTrainer
 # [Not shown] load model, tokenizer, data loaders etc
 # Create the distillation config
 distill_config = {
-    "teacher_model": (
-        _teacher_factory,
-        (
-            model_args.teacher_model,
-            training_args.cache_dir,
-        ),
-        {},
-    ),
+    "teacher_model": teacher_model,
     "criterion": LMLogitsLoss(),
     "expose_minimal_state_dict": False,
 }
diff --git a/examples/llm_qat/llama_factory/llama_factory.py b/examples/llm_qat/llama_factory/llama_factory.py
index 4e0d97ea2..80a3badb2 100644
--- a/examples/llm_qat/llama_factory/llama_factory.py
+++ b/examples/llm_qat/llama_factory/llama_factory.py
@@ -73,13 +73,6 @@ def _get_init_kwargs(model_args: ModelArguments) -> dict[str, Any]:
 mto.enable_huggingface_checkpointing()


-def _teacher_factory(model_name_or_path):
-    """Function to create a teacher model."""
-    return transformers.AutoModelForCausalLM.from_pretrained(
-        model_name_or_path,
-    )
-
-
 def parse_args():
     """Parse configuration file and extract ModelOpt quantization/distillation arguments.

@@ -221,12 +214,11 @@ def __init__(self, *args, **kwargs):
         # Initialize parent classes
         modelopt_trainer_args = {"quant_args": quant_args}
         if distill_args and distill_args.distill:
+            teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
+                distill_args.teacher_model,
+            )
             distill_config = {
-                "teacher_model": (
-                    _teacher_factory,
-                    (distill_args.teacher_model,),
-                    {},
-                ),
+                "teacher_model": teacher_model,
                 "criterion": LMLogitsLoss(),
                 "expose_minimal_state_dict": False,  # FSDP requires this to be False
             }
diff --git a/examples/llm_qat/main.py b/examples/llm_qat/main.py
index 2fa23d16a..ab1951714 100644
--- a/examples/llm_qat/main.py
+++ b/examples/llm_qat/main.py
@@ -152,15 +152,6 @@ class QuantizationArguments:
     )


-def _teacher_factory(model_name_or_path, cache_dir=None):
-    """Function to create a teacher model."""
-    return transformers.AutoModelForCausalLM.from_pretrained(
-        model_name_or_path,
-        cache_dir=cache_dir,
-        torch_dtype=torch.bfloat16,
-    )
-
-
 def train():
     parser = transformers.HfArgumentParser(
         (ModelArguments, TrainingArguments, DataArguments, QuantizationArguments)
@@ -228,15 +219,14 @@ def train():
     distill_kwargs = {}
     if training_args.distill:
         assert model_args.teacher_model is not None, "Teacher model is required for distillation."
+
+        teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
+            model_args.teacher_model,
+            cache_dir=training_args.cache_dir,
+            torch_dtype=torch.bfloat16,
+        )
         distill_config = {
-            "teacher_model": (
-                _teacher_factory,
-                (
-                    model_args.teacher_model,
-                    training_args.cache_dir,
-                ),
-                {},
-            ),
+            "teacher_model": teacher_model,
             "criterion": LMLogitsLoss(),
             "expose_minimal_state_dict": False,  # FSDP forces us to disable this
         }

From 8b83f05d0025930b8cb0903ba7cca6700f575805 Mon Sep 17 00:00:00 2001
From: Asha Anoosheh
Date: Thu, 18 Dec 2025 07:06:56 -0800
Subject: [PATCH 4/6] Enforce FSDP2 for KD/QAD

Signed-off-by: Asha Anoosheh
---
 examples/diffusers/README.md                          | 1 -
 examples/llm_qat/README.md                            | 1 -
 examples/llm_qat/launch.sh                            | 9 ++++++---
 examples/llm_qat/llama_factory/launch_llamafactory.sh | 4 ++++
 examples/llm_qat/llama_factory/llama_factory.py       | 1 -
 examples/llm_qat/main.py                              | 1 -
 tests/unit/torch/opt/plugins/test_hf_patching.py      | 1 -
 7 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/examples/diffusers/README.md b/examples/diffusers/README.md
index db4e24e07..51ba929b0 100644
--- a/examples/diffusers/README.md
+++ b/examples/diffusers/README.md
@@ -201,7 +201,6 @@ def forward(input):
 +    "teacher_model": teacher_model,
 +    "criterion": distill_config["criterion"],
 +    "loss_balancer": distill_config["loss_balancer"],
-+    "expose_minimal_state_dict": False,
 +    }
 +    transformer = mtd.convert(transformer, mode=[("kd_loss", kd_config)])

diff --git a/examples/llm_qat/README.md b/examples/llm_qat/README.md
index f5dbdaaac..2aef3bedb 100644
--- a/examples/llm_qat/README.md
+++ b/examples/llm_qat/README.md
@@ -125,7 +125,6 @@ from modelopt.torch.quantization.plugins.transformers_trainer import QADTrainer
 distill_config = {
     "teacher_model": teacher_model,
     "criterion": LMLogitsLoss(),
-    "expose_minimal_state_dict": False,
 }

 trainer = QADTrainer(
diff --git a/examples/llm_qat/launch.sh b/examples/llm_qat/launch.sh
index 5d9fc3a7b..08b4a96e4 100755
--- a/examples/llm_qat/launch.sh
+++ b/examples/llm_qat/launch.sh
@@ -139,8 +139,11 @@ esac
 DISTILLATION_ARGS=""
 if [[ "${DISTILL,,}" == "true" ]]; then
     DISTILLATION_ARGS="--distill $DISTILL --teacher_model $TEACHER_MODEL"
-    # Distillation does not work with memory efficient loading for FSDP
-    if [[ "${BACKEND,,}" == "fsdp1" || "${BACKEND,,}" == "fsdp2" ]]; then
+    if [[ "${BACKEND,,}" == "fsdp1"]]; then
+        echo "Error: Distillation does not support FSDP1. Use FSDP2 instead."
+        exit 1
+    elif [[ "${BACKEND,,}" == "fsdp2" ]]; then
+        # Distillation does not work with memory efficient loading for FSDP
         FSDP_ARGS="$FSDP_ARGS --fsdp_cpu_ram_efficient_loading False"
     fi
 fi
@@ -180,4 +183,4 @@ CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \
 start_time=$(date +%s)
 sh -c "$CMD"

-echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"
\ No newline at end of file
+echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"
diff --git a/examples/llm_qat/llama_factory/launch_llamafactory.sh b/examples/llm_qat/llama_factory/launch_llamafactory.sh
index 23e06f26a..03effc89a 100644
--- a/examples/llm_qat/llama_factory/launch_llamafactory.sh
+++ b/examples/llm_qat/llama_factory/launch_llamafactory.sh
@@ -248,6 +248,10 @@ else

     # Add teacher model specific FSDP args if needed
     if [[ "${HAS_TEACHER_MODEL,,}" == "true" ]]; then
+        if [[ "${USE_FSDP2,,}" != "true" ]]; then
+            echo "Error: Quantization aware distillation is only supported with FSDP2."
+            exit 1
+        fi
         FSDP_ARGS="$FSDP_ARGS --fsdp_cpu_ram_efficient_loading False"
     fi

diff --git a/examples/llm_qat/llama_factory/llama_factory.py b/examples/llm_qat/llama_factory/llama_factory.py
index 80a3badb2..121c3b0f3 100644
--- a/examples/llm_qat/llama_factory/llama_factory.py
+++ b/examples/llm_qat/llama_factory/llama_factory.py
@@ -220,7 +220,6 @@ def __init__(self, *args, **kwargs):
             distill_config = {
                 "teacher_model": teacher_model,
                 "criterion": LMLogitsLoss(),
-                "expose_minimal_state_dict": False,  # FSDP requires this to be False
             }
             modelopt_trainer_args["distill_config"] = distill_config
         super().__init__(*args, **modelopt_trainer_args, **kwargs)
diff --git a/examples/llm_qat/main.py b/examples/llm_qat/main.py
index ab1951714..eb4ff74cd 100644
--- a/examples/llm_qat/main.py
+++ b/examples/llm_qat/main.py
@@ -228,7 +228,6 @@ def train():
         distill_config = {
             "teacher_model": teacher_model,
             "criterion": LMLogitsLoss(),
-            "expose_minimal_state_dict": False,  # FSDP forces us to disable this
         }
         distill_kwargs["distill_config"] = distill_config
     trainer_cls = QADTrainer if training_args.distill else QATTrainer
diff --git a/tests/unit/torch/opt/plugins/test_hf_patching.py b/tests/unit/torch/opt/plugins/test_hf_patching.py
index 0b6427d18..8a44ad23c 100644
--- a/tests/unit/torch/opt/plugins/test_hf_patching.py
+++ b/tests/unit/torch/opt/plugins/test_hf_patching.py
@@ -45,7 +45,6 @@ def test_nested_model_save_restore(tmp_path, model_cls, teacher_model_type):
     kd_config = {
         "teacher_model": teacher_model,
         "criterion": mtd.LogitsDistillationLoss(),
-        "expose_minimal_state_dict": False,
     }
     model = mtd.convert(model_ref, mode=[("kd_loss", kd_config)])
     model.save_pretrained(tiny_llama_dir / "modelopt_model")

From 65727c72367215b2cb45f0aeeced955a88e3dc29 Mon Sep 17 00:00:00 2001
From: Asha Anoosheh
Date: Thu, 18 Dec 2025 07:36:57 -0800
Subject: [PATCH 5/6] README and extra changes

Signed-off-by: Asha Anoosheh
---
 examples/llm_qat/README.md                          |  7 ++++---
 examples/llm_qat/launch.sh                          |  2 +-
 modelopt/torch/distill/plugins/huggingface.py       |  3 +++
 .../quantization/plugins/transformers_trainer.py    | 10 +++------
 4 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/examples/llm_qat/README.md b/examples/llm_qat/README.md
index 2aef3bedb..250035513 100644
--- a/examples/llm_qat/README.md
+++ b/examples/llm_qat/README.md
@@ -237,7 +237,7 @@ You could also add your own customized quantization format to `CUSTOM_QUANT_CFG`

 > **_NOTE:_** `launch.sh` defaults to use `LlamaDecoderLayer` as the transformer layer class. If your model uses a different class, you need to pass `--fsdp_transformer_layer_cls_to_wrap ` to the `launch.sh` script. For example, for `Qwen/Qwen3-8B`, specify `--fsdp_transformer_layer_cls_to_wrap Qwen3DecoderLayer` as an additional argument.

-> **_NOTE:_** The script defaults to using FSDP1. To use FSDP2, pass "--use_fsdp2 True" to the `launch.sh` script. Note that FSDP2 is less stable than FSDP1 currently. Use it with caution.
+> **_NOTE:_** The script defaults to using FSDP1. To use FSDP2, pass "--backend=fsdp2" to the `launch.sh` script. Note that FSDP2 is less stable than FSDP1 currently. Use it with caution.

 ### Results

@@ -268,10 +268,11 @@ To perform QAD with logits loss, run:
     --quant_cfg NVFP4_DEFAULT_CFG \
     --do_train True \
     --output_dir llama-qad \
-    --distill True
+    --distill True \
+    --backend fsdp2
 ```

-> **_NOTE:_** QAD currently requires quantization to be applied before the FSDP wrapper. Training is not supported for models that exceed single GPU memory capacity.
+> **_NOTE:_** QAD doesn't support original FSDP - only FSDP2. It also requires quantization to be applied before the FSDP wrapper.

 ## Testing QAT model with LLM benchmarks for accuracy evaluation

diff --git a/examples/llm_qat/launch.sh b/examples/llm_qat/launch.sh
index 08b4a96e4..c9fd33095 100755
--- a/examples/llm_qat/launch.sh
+++ b/examples/llm_qat/launch.sh
@@ -139,7 +139,7 @@ esac
 DISTILLATION_ARGS=""
 if [[ "${DISTILL,,}" == "true" ]]; then
     DISTILLATION_ARGS="--distill $DISTILL --teacher_model $TEACHER_MODEL"
-    if [[ "${BACKEND,,}" == "fsdp1"]]; then
+    if [[ "${BACKEND,,}" == "fsdp1" ]]; then
         echo "Error: Distillation does not support FSDP1. Use FSDP2 instead."
         exit 1
     elif [[ "${BACKEND,,}" == "fsdp2" ]]; then
diff --git a/modelopt/torch/distill/plugins/huggingface.py b/modelopt/torch/distill/plugins/huggingface.py
index c52ebaf08..cb0c2d6bd 100644
--- a/modelopt/torch/distill/plugins/huggingface.py
+++ b/modelopt/torch/distill/plugins/huggingface.py
@@ -58,6 +58,9 @@ def save_model(
     ):
         """Dumps model and ModelOpt states to disk.

+        Note: Will save pretrained model in safetensors format if called manually, otherwise will
+        save in training checkpoint format (when called internally by transformers Trainer).
+
         Args:
             output_dir: The directory to save the model and ModelOpt states.
         """
diff --git a/modelopt/torch/quantization/plugins/transformers_trainer.py b/modelopt/torch/quantization/plugins/transformers_trainer.py
index 2f8847d1f..76e16094b 100644
--- a/modelopt/torch/quantization/plugins/transformers_trainer.py
+++ b/modelopt/torch/quantization/plugins/transformers_trainer.py
@@ -24,10 +24,9 @@
 import torch
 from tqdm import tqdm

+import modelopt.torch.distill as mtd
 import modelopt.torch.opt as mto
 import modelopt.torch.quantization as mtq
-from modelopt.torch.distill import KDLossConfig
-from modelopt.torch.distill.mode import _convert_for_kd
 from modelopt.torch.distill.plugins.huggingface import KDTrainer
 from modelopt.torch.opt.conversion import restore_from_modelopt_state
 from modelopt.torch.opt.plugins import ModelOptHFTrainer
@@ -386,9 +385,7 @@ def __init__(

         super().__init__(*args, **kwargs)

-        # Note: QAD doesn't work with FSDP wrapped model. We quantize model before the wrapper.
-        # The drawback is that we can't train a model that is bigger than a single GPU memory.
-        # And memory efficient loading doesn't work.
+        # Note: FSDP memory efficient loading doesn't work.
         self.model.cuda()
         if self.quant_cfg is not None and not is_quantized(self.model):
             self._quantize_model()
@@ -399,8 +396,7 @@ def __init__(

     def _convert_to_distillation_model(self):
         """Convert the model to a distillation model."""
-        # We don't need any save/restore feature of the distallation mode, so we skip it here.
-        _convert_for_kd(self.model, KDLossConfig(**self.distill_config))
+        mtd.convert(self.model, mode=[("kd_loss", self.distill_config)])
         print_rank_0("Distillation model created.")

From 190e4d2b3a18a0e4948bfff49f60561463291f53 Mon Sep 17 00:00:00 2001
From: Asha Anoosheh
Date: Mon, 22 Dec 2025 06:41:46 -0800
Subject: [PATCH 6/6] Remove outdated FSDP2 notes

Signed-off-by: Asha Anoosheh
---
 examples/llm_qat/README.md | 2 +-
 examples/llm_qat/launch.sh | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/llm_qat/README.md b/examples/llm_qat/README.md
index 250035513..fed12d1e6 100644
--- a/examples/llm_qat/README.md
+++ b/examples/llm_qat/README.md
@@ -237,7 +237,7 @@ You could also add your own customized quantization format to `CUSTOM_QUANT_CFG`

 > **_NOTE:_** `launch.sh` defaults to use `LlamaDecoderLayer` as the transformer layer class. If your model uses a different class, you need to pass `--fsdp_transformer_layer_cls_to_wrap ` to the `launch.sh` script. For example, for `Qwen/Qwen3-8B`, specify `--fsdp_transformer_layer_cls_to_wrap Qwen3DecoderLayer` as an additional argument.

-> **_NOTE:_** The script defaults to using FSDP1. To use FSDP2, pass "--backend=fsdp2" to the `launch.sh` script. Note that FSDP2 is less stable than FSDP1 currently. Use it with caution.
+> **_NOTE:_** The script defaults to using FSDP1. To use FSDP2, pass "--backend=fsdp2" to the `launch.sh` script.

 ### Results

diff --git a/examples/llm_qat/launch.sh b/examples/llm_qat/launch.sh
index c9fd33095..1500e98b7 100755
--- a/examples/llm_qat/launch.sh
+++ b/examples/llm_qat/launch.sh
@@ -115,7 +115,7 @@ case "${BACKEND,,}" in
         FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
         ;;
     "fsdp2")
-        echo "Using FSDP2 instead of FSDP1. FSDP2 is not mature yet! Please use it with latest torch and transformers."
+        echo "Using FSDP2 instead of FSDP1."
         CONFIG_FILE="fsdp2.yaml"
         FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
         ;;