@@ -861,11 +861,6 @@ def postprocess_qa_predictions(
return eval_metric


# Model Optimizer: Define a teacher factory for initializing the distillation model
def teacher_factory(model_name_or_path):
return AutoModelForQuestionAnswering.from_pretrained(model_name_or_path)


# Model Optimizer: Define a custom distillation loss function that uses start and end logits
class StartEndLogitsDistillationLoss(mtd.LogitsDistillationLoss):
def forward(self, outputs_s, outputs_t):
@@ -1199,7 +1194,9 @@ def forward_loop(model):
logger.info(f"Using distillation with teacher {args.model_name_or_path}")

kd_config = {
"teacher_model": (teacher_factory, (args.model_name_or_path,), {}),
"teacher_model": AutoModelForQuestionAnswering.from_pretrained(
args.model_name_or_path,
),
"criterion": StartEndLogitsDistillationLoss(args.temperature),
}
model = mtd.convert(model, mode=[("kd_loss", kd_config)])
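Put together, the distillation setup in this example now reads roughly as below. This is a sketch, not a verbatim excerpt: `args`, `model`, and `StartEndLogitsDistillationLoss` are the names defined earlier in the script, and the checkpoint is whatever `--model_name_or_path` points to.

```python
import modelopt.torch.distill as mtd
from transformers import AutoModelForQuestionAnswering

# The teacher is now a plain model instance rather than a (factory, args, kwargs) tuple.
teacher = AutoModelForQuestionAnswering.from_pretrained(args.model_name_or_path)

kd_config = {
    "teacher_model": teacher,
    "criterion": StartEndLogitsDistillationLoss(args.temperature),
}
# Wrap the student so forward() also runs the teacher and the KD loss can be computed.
model = mtd.convert(model, mode=[("kd_loss", kd_config)])
```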
1 change: 0 additions & 1 deletion examples/diffusers/README.md
@@ -201,7 +201,6 @@ def forward(input):
+ "teacher_model": teacher_model,
+ "criterion": distill_config["criterion"],
+ "loss_balancer": distill_config["loss_balancer"],
+ "expose_minimal_state_dict": False,
+ }
+ transformer = mtd.convert(transformer, mode=[("kd_loss", kd_config)])

2 changes: 1 addition & 1 deletion examples/llm_distill/main.py
@@ -153,7 +153,7 @@ def train():
# Save checkpoint
logger.info("Saving checkpoint...")
trainer.save_state()
trainer.save_model(trainer.args.output_dir, export_student=True)
trainer.save_model(trainer.args.output_dir)
logger.info("Checkpoint saved.")


19 changes: 6 additions & 13 deletions examples/llm_qat/README.md
@@ -123,16 +123,8 @@ from modelopt.torch.quantization.plugins.transformers_trainer import QADTrainer
# [Not shown] load model, tokenizer, data loaders etc
# Create the distillation config
distill_config = {
"teacher_model": (
_teacher_factory,
(
model_args.teacher_model,
training_args.cache_dir,
),
{},
),
"teacher_model": teacher_model,
"criterion": LMLogitsLoss(),
"expose_minimal_state_dict": False,
}

trainer = QADTrainer(
Expand All @@ -147,7 +139,7 @@ trainer = QADTrainer(
trainer.train() # Train the quantized model using distillation (i.e, QAD)

# Save the final student model weights; An example usage
trainer.save_model(export_student=True)
trainer.save_model()
```

### NeMo QAT/QAD Simplified Flow Example
@@ -245,7 +237,7 @@ You could also add your own customized quantization format to `CUSTOM_QUANT_CFG`

> **_NOTE:_** `launch.sh` defaults to use `LlamaDecoderLayer` as the transformer layer class. If your model uses a different class, you need to pass `--fsdp_transformer_layer_cls_to_wrap <your_layer_class>` to the `launch.sh` script. For example, for `Qwen/Qwen3-8B`, specify `--fsdp_transformer_layer_cls_to_wrap Qwen3DecoderLayer` as an additional argument.

> **_NOTE:_** The script defaults to using FSDP1. To use FSDP2, pass "--use_fsdp2 True" to the `launch.sh` script. Note that FSDP2 is less stable than FSDP1 currently. Use it with caution.
> **_NOTE:_** The script defaults to using FSDP1. To use FSDP2, pass "--backend=fsdp2" to the `launch.sh` script.

### Results

@@ -276,10 +268,11 @@ To perform QAD with logits loss, run:
--quant_cfg NVFP4_DEFAULT_CFG \
--do_train True \
--output_dir llama-qad \
--distill True
--distill True \
--backend fsdp2
```

> **_NOTE:_** QAD currently requires quantization to be applied before the FSDP wrapper. Training is not supported for models that exceed single GPU memory capacity.
> **_NOTE:_** QAD doesn't support original FSDP - only FSDP2. It also requires quantization to be applied before the FSDP wrapper.
Suggested change from a reviewer (replacing the note above):
> **_NOTE:_** QAD doesn't support FSDP1 (https://docs.pytorch.org/docs/stable/fsdp.html) - only FSDP2. It also requires quantization to be applied before the FSDP wrapper.


## Testing QAT model with LLM benchmarks for accuracy evaluation

11 changes: 7 additions & 4 deletions examples/llm_qat/launch.sh
@@ -115,7 +115,7 @@ case "${BACKEND,,}" in
FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
;;
"fsdp2")
echo "Using FSDP2 instead of FSDP1. FSDP2 is not mature yet! Please use it with latest torch and transformers."
echo "Using FSDP2 instead of FSDP1."
CONFIG_FILE="fsdp2.yaml"
FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap $FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP"
;;
@@ -139,8 +139,11 @@ esac
DISTILLATION_ARGS=""
if [[ "${DISTILL,,}" == "true" ]]; then
DISTILLATION_ARGS="--distill $DISTILL --teacher_model $TEACHER_MODEL"
# Distillation does not work with memory efficient loading for FSDP
if [[ "${BACKEND,,}" == "fsdp1" || "${BACKEND,,}" == "fsdp2" ]]; then
if [[ "${BACKEND,,}" == "fsdp1" ]]; then
echo "Error: Distillation does not support FSDP1. Use FSDP2 instead."
exit 1
elif [[ "${BACKEND,,}" == "fsdp2" ]]; then
# Distillation does not work with memory efficient loading for FSDP
FSDP_ARGS="$FSDP_ARGS --fsdp_cpu_ram_efficient_loading False"
fi
fi
@@ -180,4 +183,4 @@ CMD="accelerate launch --config-file accelerate_config/$CONFIG_FILE $FSDP_ARGS \

start_time=$(date +%s)
sh -c "$CMD"
echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"
echo "Total time taken: $(( $(date +%s) - $start_time )) seconds"
4 changes: 4 additions & 0 deletions examples/llm_qat/llama_factory/launch_llamafactory.sh
@@ -248,6 +248,10 @@ else

# Add teacher model specific FSDP args if needed
if [[ "${HAS_TEACHER_MODEL,,}" == "true" ]]; then
if [[ "${USE_FSDP2,,}" != "true" ]]; then
echo "Error: Quantization aware distillation is only supported with FSDP2."
exit 1
fi
FSDP_ARGS="$FSDP_ARGS --fsdp_cpu_ram_efficient_loading False"
fi

21 changes: 5 additions & 16 deletions examples/llm_qat/llama_factory/llama_factory.py
@@ -73,13 +73,6 @@ def _get_init_kwargs(model_args: ModelArguments) -> dict[str, Any]:
mto.enable_huggingface_checkpointing()


def _teacher_factory(model_name_or_path):
"""Function to create a teacher model."""
return transformers.AutoModelForCausalLM.from_pretrained(
model_name_or_path,
)


def parse_args():
"""Parse configuration file and extract ModelOpt quantization/distillation arguments.

@@ -221,14 +214,12 @@ def __init__(self, *args, **kwargs):
# Initialize parent classes
modelopt_trainer_args = {"quant_args": quant_args}
if distill_args and distill_args.distill:
teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
distill_args.teacher_model,
)
distill_config = {
"teacher_model": (
_teacher_factory,
(distill_args.teacher_model,),
{},
),
"teacher_model": teacher_model,
"criterion": LMLogitsLoss(),
"expose_minimal_state_dict": False, # FSDP requires this to be False
}
modelopt_trainer_args["distill_config"] = distill_config
super().__init__(*args, **modelopt_trainer_args, **kwargs)
@@ -249,11 +240,9 @@ def create_modelcard_and_push(
) -> None:
original_fn(trainer, *args, **kwargs)

# export the student model for quantization aware distillation
kwargs = {"export_student": True} if hasattr(trainer, "distill_config") else {}
# save the model in the output directory
trainer.save_state()
trainer.save_model(output_dir=trainer.args.output_dir, **kwargs)
trainer.save_model(output_dir=trainer.args.output_dir)

module.create_modelcard_and_push = create_modelcard_and_push

28 changes: 8 additions & 20 deletions examples/llm_qat/main.py
@@ -152,15 +152,6 @@ class QuantizationArguments:
)


def _teacher_factory(model_name_or_path, cache_dir=None):
"""Function to create a teacher model."""
return transformers.AutoModelForCausalLM.from_pretrained(
model_name_or_path,
cache_dir=cache_dir,
torch_dtype=torch.bfloat16,
)


def train():
parser = transformers.HfArgumentParser(
(ModelArguments, TrainingArguments, DataArguments, QuantizationArguments)
@@ -228,17 +219,15 @@ def train():
distill_kwargs = {}
if training_args.distill:
assert model_args.teacher_model is not None, "Teacher model is required for distillation."

teacher_model = transformers.AutoModelForCausalLM.from_pretrained(
model_args.teacher_model,
cache_dir=training_args.cache_dir,
torch_dtype=torch.bfloat16,
)
distill_config = {
"teacher_model": (
_teacher_factory,
(
model_args.teacher_model,
training_args.cache_dir,
),
{},
),
"teacher_model": teacher_model,
"criterion": LMLogitsLoss(),
"expose_minimal_state_dict": False, # FSDP forces us to disable this
}
distill_kwargs["distill_config"] = distill_config
trainer_cls = QADTrainer if training_args.distill else QATTrainer
@@ -270,8 +259,7 @@ def train():
if training_args.do_train or quant_args.quant_cfg is not None:
print_rank_0("Saving the model...")
trainer.save_state()
kwargs = {"export_student": True} if training_args.distill else {}
trainer.save_model(training_args.output_dir, **kwargs)
trainer.save_model(training_args.output_dir)


if __name__ == "__main__":
54 changes: 27 additions & 27 deletions modelopt/torch/distill/plugins/huggingface.py
@@ -15,17 +15,21 @@

"""ModelOpt plugin to train HuggingFace models with knowledge distillation."""

import torch
from transformers.modeling_outputs import CausalLMOutputWithPast

import modelopt.torch.distill as mtd
import modelopt.torch.opt as mto
from modelopt.torch.opt.plugins import ModelOptHFTrainer


class KDTrainer(ModelOptHFTrainer):
"""Distillation trainer for HuggingFace models."""

def __init__(self, *args, **kwargs):
"""Initialize the trainer."""
super().__init__(*args, **kwargs)
if self.is_fsdp_enabled and not self.accelerator.is_fsdp2:
raise ValueError("FSDP1 is not supported for distillation. Use FSDP2 instead.")

def compute_loss(self, model, inputs, *args, **kwargs):
"""Compute loss for distillation.

@@ -49,44 +53,40 @@ def save_model(
self,
output_dir: str | None = None,
_internal_call: bool = False,
export_student: bool = False,
*args,
**kwargs,
):
"""Dumps model and ModelOpt states to disk.

Note: Will save pretrained model in safetensors format if called manually, otherwise will
save in training checkpoint format (when called internally by transformers Trainer).

Args:
output_dir: The directory to save the model and ModelOpt states.
export_student: Whether to export the student model.

"""
if output_dir is None:
output_dir = self.args.output_dir

model = self.accelerator.unwrap_model(self.model)
if not _internal_call and self.is_fsdp_enabled:
with model.hide_teacher_model(enable=export_student):
state_dict = self.accelerator.get_state_dict(self.model)
modelopt_state = mto.modelopt_state(model)
if export_student:
# Need to wait, otherwise FSDP weights may be deleted before rank 0 can gather them
self.accelerator.wait_for_everyone()
model = model.export()

if self.accelerator.is_main_process:
model.save_pretrained(
output_dir,
is_main_process=self.accelerator.is_main_process,
save_function=self.accelerator.save,
state_dict=state_dict,
)
self.processing_class.save_pretrained(output_dir)
torch.save(modelopt_state, f"{output_dir}/modelopt_state.pth")
else:
model = model.export() if export_student else model
super().save_model(output_dir, _internal_call, *args, **kwargs)
with model.hide_teacher_model(), model.hide_loss_modules(enable=not _internal_call):
if _internal_call:
Reviewer comment (Contributor): Why do we need to save the teacher if `_internal_call = True`? Can we avoid saving the teacher regardless of `_internal_call`? That would accelerate checkpoint save/load (you would likely also need to hide the teacher during the final checkpoint load).

return super().save_model(output_dir, _internal_call, *args, **kwargs)

extra_kwargs = {}
if self.is_fsdp_enabled:
extra_kwargs["save_function"] = self.accelerator.save
extra_kwargs["state_dict"] = self.accelerator.get_state_dict(self.model)
self.accelerator.wait_for_everyone() # needed to prevent hang somehow

model.save_pretrained(
output_dir,
is_main_process=self.accelerator.is_main_process,
**extra_kwargs,
)
self.processing_class.save_pretrained(output_dir)

def train(self, *args, **kwargs):
"""Train the model."""
"""Enable or disable training/evaluation mode."""
self.compute_loss_func = lambda *args, **kwargs: self.model.compute_kd_loss()
return super().train(*args, **kwargs)

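On the caller side, the effect of the `save_model` change above is that there is no `export_student` flag anymore: the teacher (and, for manual saves, the loss modules) is hidden automatically, so a plain call always writes a student-only checkpoint. A minimal sketch, assuming `trainer` is a `KDTrainer` and `output_dir` is your checkpoint directory:

```python
# Previously: trainer.save_model(output_dir, export_student=True)
# Now a plain call is enough; only the student weights are serialized.
trainer.save_state()
trainer.save_model(output_dir)
```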
30 changes: 10 additions & 20 deletions modelopt/torch/quantization/plugins/transformers_trainer.py
@@ -24,10 +24,9 @@
import torch
from tqdm import tqdm

import modelopt.torch.distill as mtd
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
from modelopt.torch.distill import KDLossConfig
from modelopt.torch.distill.mode import _convert_for_kd
from modelopt.torch.distill.plugins.huggingface import KDTrainer
from modelopt.torch.opt.conversion import restore_from_modelopt_state
from modelopt.torch.opt.plugins import ModelOptHFTrainer
@@ -255,8 +254,7 @@ def train(self, *args, **kwargs):
"""Train the model."""
outputs = super().train(*args, **kwargs)
print_rank_0(
"Training completed. Please save the final model using `Trainer.save_model()` "
"to preserve ModelOpt states."
"Training completed. Please save the final model using `Trainer.save_model()` to preserve ModelOpt states."
)
return outputs

@@ -271,8 +269,7 @@ def save_model(self, *args, **kwargs):
original_type = self.accelerator.state.fsdp_plugin.state_dict_type
self.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
outputs = super().save_model(*args, **kwargs)
if torch.distributed.is_initialized():
torch.distributed.barrier()
self.accelerator.wait_for_everyone()
if mto.ModeloptStateManager.is_converted(self.accelerator.unwrap_model(self.model)):
print_rank_0(
"Model saved. To restore, call mto.enable_huggingface_checkpointing() first before loading the "
Expand Down Expand Up @@ -388,9 +385,7 @@ def __init__(

super().__init__(*args, **kwargs)

# Note: QAD doesn't work with FSDP wrapped model. We quantize model before the wrapper.
# The drawback is that we can't train a model that is bigger than a single GPU memory.
# And memory efficient loading doesn't work.
# Note: FSDP memory efficient loading doesn't work.
self.model.cuda()
if self.quant_cfg is not None and not is_quantized(self.model):
self._quantize_model()
Expand All @@ -401,8 +396,7 @@ def __init__(

def _convert_to_distillation_model(self):
"""Convert the model to a distillation model."""
# We don't need any save/restore feature of the distillation mode, so we skip it here.
_convert_for_kd(self.model, KDLossConfig(**self.distill_config))
mtd.convert(self.model, mode=[("kd_loss", self.distill_config)])
print_rank_0("Distillation model created.")

def train(self, *args, **kwargs):
@@ -414,23 +408,19 @@ def save_model(
self,
output_dir: str | None = None,
_internal_call: bool = False,
export_student: bool = False,
*args,
**kwargs,
):
"""Dumps model to disk without teacher model and loss modules.

Args:
output_dir: The directory to save the model and ModelOpt states.
export_student: Whether to export the student model.
"""
if self.accelerator.is_fsdp2 and "SHARDED_STATE_DICT" in str(
self.accelerator.state.fsdp_plugin.state_dict_type
):
if export_student:
model = self.accelerator.unwrap_model(self.model)
model = model.export()
return QATTrainer.save_model(self, output_dir, _internal_call, *args, **kwargs)
return KDTrainer.save_model(
self, output_dir, _internal_call, export_student, *args, **kwargs
)
model = self.accelerator.unwrap_model(self.model)
with model.hide_teacher_model(), model.hide_loss_modules(enable=not _internal_call):
return QATTrainer.save_model(self, output_dir, _internal_call, *args, **kwargs)
else:
return KDTrainer.save_model(self, output_dir, _internal_call, *args, **kwargs)
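Loading a checkpoint saved by these trainers is not shown in this diff, but the log message in `QATTrainer.save_model` ("call `mto.enable_huggingface_checkpointing()` first before loading") points at the intended path. A hedged sketch; the directory name is just the `--output_dir` from the README's QAD example:

```python
import modelopt.torch.opt as mto
from transformers import AutoModelForCausalLM

# Install ModelOpt's HuggingFace save/load hooks before from_pretrained so any
# ModelOpt state stored with the checkpoint is restored along with the weights.
mto.enable_huggingface_checkpointing()
model = AutoModelForCausalLM.from_pretrained("llama-qad")
```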