examples/models/qwen3_5_moe/export.py (25 additions, 0 deletions)
```diff
@@ -196,6 +196,12 @@ def load_and_quantize(args):  # noqa: C901
     # CUDA: quantize experts with packed INT4 for Triton kernel
     if args.qlinear or args.qembedding:
         _quantize(model, config, args)
+        # Unwrap torchao tensor subclasses (e.g. AffineQuantizedTensor)
+        # into parametrized plain tensors so torch.export can handle them.
+        # Mirrors executorch/export/stages.py SourceTransformStage.
+        from torchao.utils import unwrap_tensor_subclass
+
+        unwrap_tensor_subclass(model)
     else:
         model.to(dtype=torch.bfloat16)
```
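For context, here is a minimal sketch of the quantize-then-unwrap flow this hunk (and the one below in `load_prequantized_model`) relies on. The toy `Linear` model and the `int8_weight_only` config are assumptions standing in for the packed-INT4 expert quantization done by `_quantize`:

```python
# Minimal sketch (not from this PR): quantize with torchao, then unwrap
# the tensor subclasses so torch.export can trace the model.
# Assumptions: int8_weight_only stands in for the packed-INT4 config,
# and the toy Linear stands in for the MoE experts.
import torch
from torchao.quantization import int8_weight_only, quantize_
from torchao.utils import unwrap_tensor_subclass

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()
quantize_(model, int8_weight_only())
# Weights are now AffineQuantizedTensor subclasses, which torch.export
# cannot trace directly.

unwrap_tensor_subclass(model)
# Weights are plain tensors rebuilt on access via torch.nn.utils.parametrize,
# so the exported graph sees only ordinary tensor ops.

exported = torch.export.export(model, (torch.randn(1, 64),))
```

Calling `unwrap_tensor_subclass` after quantization but before export is the same ordering both call sites in this diff enforce.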
```diff
@@ -290,6 +296,14 @@ def load_prequantized_model(prequantized_dir, max_seq_len=4096, use_splitk_decod
     # requires_grad=True, which fails. Disable grad on all parameters.
     for p in model.parameters():
         p.requires_grad_(False)
+
+    # Unwrap any torchao tensor subclasses (e.g. AffineQuantizedTensor) into
+    # parametrized plain tensors so torch.export can handle them. Mirrors the
+    # canonical pattern in executorch/export/stages.py SourceTransformStage.
+    from torchao.utils import unwrap_tensor_subclass
+
+    unwrap_tensor_subclass(model)
+
     model.eval()

     print(
```
```diff
@@ -966,6 +980,9 @@ def main():  # noqa: C901
         # Register FLA Triton kernel (CUDA only)
         import executorch.backends.cuda.triton.kernels  # noqa: F401

+    if torch.cuda.is_available():
+        torch.cuda.reset_peak_memory_stats()
+
     if args.backend == "mlx":
         if args.prequantized:
             parser.error("--prequantized is not supported with --backend mlx")
```
```diff
@@ -988,6 +1005,14 @@

     export_and_lower(model, config, args)

+    if args.backend == "cuda" and torch.cuda.is_available():
+        peak_alloc_gb = torch.cuda.max_memory_allocated() / 1e9
+        peak_reserved_gb = torch.cuda.max_memory_reserved() / 1e9
+        print(
+            f"[CUDA peak memory] allocated={peak_alloc_gb:.2f} GB, "
+            f"reserved={peak_reserved_gb:.2f} GB"
+        )
+

 if __name__ == "__main__":
     main()
```
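The reset in the first `main()` hunk and the read-out in the second bracket the entire export, so the printed figures cover model loading, quantization, and lowering. The same pattern as a standalone helper; the `report_peak` wrapper and its name are hypothetical, not part of the PR:

```python
# Hypothetical helper (not in this PR) showing the bracketing pattern:
# reset the peak counters before the workload, read them afterwards.
import torch

def report_peak(fn, *args, **kwargs):
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    result = fn(*args, **kwargs)
    if torch.cuda.is_available():
        peak_alloc_gb = torch.cuda.max_memory_allocated() / 1e9
        peak_reserved_gb = torch.cuda.max_memory_reserved() / 1e9
        print(
            f"[CUDA peak memory] allocated={peak_alloc_gb:.2f} GB, "
            f"reserved={peak_reserved_gb:.2f} GB"
        )
    return result
```

Note the distinction between the two figures: `max_memory_allocated` counts live tensor allocations, while `max_memory_reserved` includes memory held by the caching allocator, so the reserved figure is always at least as large.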