55 changes: 15 additions & 40 deletions examples/quantization_w8a8_fp8/granite4_example.py
@@ -1,67 +1,42 @@
-from compressed_tensors.utils import replace_module
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from transformers import AutoModelForCausalLM, AutoTokenizer
-from transformers.models.granitemoehybrid.modeling_granitemoehybrid import (
-    GraniteMoeHybridParallelExperts,
-)
 
 from llmcompressor import oneshot
-from llmcompressor.modeling.granite4 import GraniteMoeHybridParallelExpertsLinear
 from llmcompressor.modifiers.quantization import QuantizationModifier
 from llmcompressor.utils import dispatch_for_generation
+from llmcompressor.modeling import replace_modules_for_calibration
 
-"""Please see details in `README_granite4.md`."""
+MODEL_ID = "ibm-granite/granite-4.0-h-small"
 
-MODEL_ID = "ibm-granite/granite-4.0-tiny-preview"
-
 # Load model.
 model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
-skip_router_only = True  # assume we want to quantize input/output moe layers
-ignore_lay = [
-    "lm_head",
-]
-if skip_router_only:
-    # swap moe linears to a custom class
-    for n, m in model.named_modules():
-        if isinstance(m, GraniteMoeHybridParallelExperts):
-            new_mod = GraniteMoeHybridParallelExpertsLinear.from_3d_expert(m)
-            replace_module(model, n, new_mod)
-    ignore_lay += ["re:.*block_sparse_moe.router"]
-    SAVE_DIR = "ibm-granite-4-tiny-fp8-dynamic-skipMoeRouter"
-else:
-    # Skip all .input_linear, .output_linear, and router layers.
-    ignore_lay += ["re:.*block_sparse_moe"]
-    SAVE_DIR = "ibm-granite-4-tiny-fp8-dynamic-skipMoe"
+model = replace_modules_for_calibration(model)
 
+ignore_lay = ["lm_head"]
 
 recipe = QuantizationModifier(
-    targets=["Linear", "GraniteMoeHybridParallelExpertsLinear"],
+    targets=["Linear"],
     scheme="FP8_DYNAMIC",
     ignore=ignore_lay,
 )
 
 # Apply quantization.
 oneshot(model=model, recipe=recipe)
 
 # Confirm generations of the quantized model look sane.
 print("========== SAMPLE GENERATION ==============")
 dispatch_for_generation(model)
 input_ids = tokenizer(
-    "What is your favorite TV show?", return_tensors="pt"
-).input_ids.to("cuda")
-output = model.generate(input_ids, max_new_tokens=20)
+    "Describe Large Language Model", return_tensors="pt"
+).input_ids.to(model.device)
+output = model.generate(input_ids, max_new_tokens=35)
 print(tokenizer.decode(output[0]))
 print("==========================================")
 
-# Revert weights of MoE experts to 3D format (num_experts, output_size, input_size)
-for n, m in model.named_modules():
-    if isinstance(m, GraniteMoeHybridParallelExpertsLinear):
-        # NOTE: can assert type != "meta" instead, which is a sign of offloading
-        assert m.weight.device.type == "cuda", (
-            "Found some offloaded weights. This is not compatible with reshaping "
-            "experts to 3D prior to model save. Ensure the model is fully on cuda."
-        )
-        m.to_3d_expert()
+SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8-dynamic"
+print(f"Saving to {SAVE_DIR}")
 
 model.save_pretrained(SAVE_DIR)
 tokenizer.save_pretrained(SAVE_DIR)
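Sanity-checking the result: the saved directory is a standard compressed-tensors checkpoint, so it can be reloaded for inference. A minimal sketch, assuming the directory name produced above and a vLLM build whose compressed-tensors support covers Granite 4 FP8 checkpoints:

from vllm import LLM, SamplingParams

# Load the FP8-dynamic checkpoint saved by the example above.
llm = LLM(model="granite-4.0-h-small-FP8-dynamic")
params = SamplingParams(temperature=0.0, max_tokens=35)
outputs = llm.generate(["Describe Large Language Model"], params)
print(outputs[0].outputs[0].text)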
45 changes: 45 additions & 0 deletions examples/quantization_w8a8_fp8/granite4_fp8_block_example.py
@@ -0,0 +1,45 @@
from transformers import AutoModelForCausalLM, AutoTokenizer

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.utils import dispatch_for_generation
from llmcompressor.modeling import replace_modules_for_calibration
from llmcompressor.modeling.granite4 import pack_3d_experts

MODEL_ID = "ibm-granite/granite-4.0-h-small"

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Swap the fused 3D MoE expert modules for per-expert Linear layers so that
# calibration and quantization can target them as standard Linear modules.
model = replace_modules_for_calibration(model)

# Keep the lm_head, MoE routers, mamba input projections, and shared-MLP
# input layers unquantized.
ignore_lay = [
    "lm_head",
    "re:.*block_sparse_moe.router",
    "re:.*mamba.in_proj",
    "re:.*shared_mlp.input_linear",
]

# FP8_BLOCK applies block-wise FP8 weight quantization (e.g., 128x128 weight
# blocks), in contrast to the FP8_DYNAMIC scheme used in granite4_example.py.
recipe = QuantizationModifier(
    targets=["Linear"],
    scheme="FP8_BLOCK",
    ignore=ignore_lay,
)

oneshot(model=model, recipe=recipe)

print("========== SAMPLE GENERATION ==============")
dispatch_for_generation(model)
input_ids = tokenizer(
"Describe Large Language Model", return_tensors="pt"
).input_ids.to(model.device)
output = model.generate(input_ids, max_new_tokens=35)
print(tokenizer.decode(output[0]))
print("==========================================")

SAVE_DIR = MODEL_ID.rstrip("/").split("/")[-1] + "-FP8-block"
print(f"Saving to {SAVE_DIR}")

model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
# Re-pack the per-expert 2D Linear weights in the saved checkpoint back into
# the 3D (num_experts, output_size, input_size) layout expected by the model.
pack_3d_experts(SAVE_DIR)
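For intuition on that last step: the previous version of granite4_example.py reverted each expert's weight to a 3D tensor in memory (via to_3d_expert()) before saving, while pack_3d_experts performs the equivalent repacking on the checkpoint after it is written. Conceptually the repacking amounts to something like the following illustrative sketch (not the library's actual implementation; expert_weights is a hypothetical list of per-expert weights):

import torch

def stack_experts(expert_weights: list) -> torch.Tensor:
    # Stack num_experts matrices of shape (output_size, input_size) into one
    # (num_experts, output_size, input_size) tensor, i.e. the layout named in
    # the MoE revert code removed from granite4_example.py above.
    return torch.stack(expert_weights, dim=0)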
