
Commit 7c7a188

Update pack experts
1 parent 80f779c commit 7c7a188

File tree

2 files changed: +185 -0 lines changed


examples/quantization_w8a8_fp8/granite4_fp8_block_example.py

Lines changed: 3 additions & 0 deletions
@@ -8,6 +8,7 @@
 from llmcompressor.modifiers.quantization import QuantizationModifier
 from llmcompressor.utils import dispatch_for_generation
 from llmcompressor.modeling import replace_modules_for_calibration
+from llmcompressor.modeling.granite4 import pack_3d_experts

 MODEL_ID = "ibm-granite/granite-4.0-h-small"

@@ -40,3 +41,5 @@

 model.save_pretrained(SAVE_DIR)
 tokenizer.save_pretrained(SAVE_DIR)
+pack_3d_experts(SAVE_DIR)
+
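The two added lines import the packer and invoke it on the checkpoint directory right after save_pretrained, so the saved safetensors shards are rewritten in place. As a quick post-run sanity check, one might confirm the rewrite took effect (a minimal sketch: the SAVE_DIR value is a placeholder, and the assertion mirrors the function's own verification step):

import json
from pathlib import Path

SAVE_DIR = "granite-4.0-h-small-FP8-BLOCK"  # hypothetical; substitute the script's actual value

index_path = Path(SAVE_DIR) / "model.safetensors.index.json"
weight_map = json.loads(index_path.read_text())["weight_map"]

# After packing, no per-expert entries should remain in the weight map.
assert not any(".experts." in name for name in weight_map), "unpacked experts remain"
print(f"{len(weight_map)} tensors indexed, all experts packed")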
src/llmcompressor/modeling/granite4.py

Lines changed: 182 additions & 0 deletions
@@ -253,7 +253,189 @@ def replace_moe_modules(module, name=''):
     return model


+def pack_3d_experts(
+    source_dir,
+    validate=True,
+    backup=True,
+    allow_missing_experts=False,
+    verbose=True
+):
+    """
+    Transform MoE model from per-expert storage to 3D stacked tensors.
+
+    From: model.layers.{L}.block_sparse_moe.{linear_type}.experts.{E}.{param}
+    To:   model.layers.{L}.block_sparse_moe.{linear_type}.{param}
+
+    Args:
+        source_dir: Model directory path
+        validate: Validate shapes and expert continuity
+        backup: Create backup before modification (RECOMMENDED)
+        allow_missing_experts: Don't fail if some experts are missing
+        verbose: Print progress messages
+    """
+    source_dir = Path(source_dir)
+    index_file = source_dir / "model.safetensors.index.json"
+    backup_dir = None
+    temp_files = []
+
+    def log(msg):
+        if verbose: print(msg)
+
+    try:
+        # === BACKUP ===
+        if backup:
+            backup_dir = source_dir.parent / f"{source_dir.name}.backup.{datetime.now().strftime('%Y%m%d_%H%M%S')}"
+            backup_dir.mkdir(parents=True)
+            for f in source_dir.glob("*.safetensors*"):
+                shutil.copy2(f, backup_dir / f.name)
+            log(f"✓ Backup created at {backup_dir}")
+
+        # === LOAD INDEX ===
+        with open(index_file) as f:
+            index_data = json.load(f)
+        weight_map = index_data["weight_map"]
+
+        # === GROUP TENSORS ===
+        grouped = defaultdict(dict)  # {(layer, linear_type, param): {expert_num: (name, file)}}
+        other = {}
+
+        for name, file in weight_map.items():
+            if ".block_sparse_moe." in name and ".experts." in name:
+                parts = name.split(".")
+                try:
+                    layer = int(parts[parts.index("layers") + 1])
+                    expert = int(parts[parts.index("experts") + 1])
+                    linear_type = parts[parts.index("experts") - 1]
+                    param = ".".join(parts[parts.index("experts") + 2:])
+                    grouped[(layer, linear_type, param)][expert] = (name, file)
+                except (ValueError, IndexError):
+                    other[name] = file
+            else:
+                other[name] = file
+
+        log(f"✓ Found {len(grouped)} expert groups, {len(other)} other tensors")
+
+        # === LOAD FILES ===
+        log("Loading files...")
+        loaded = {}
+        old_files = set(weight_map.values())
+        for file in old_files:
+            loaded[file] = load_file(str(source_dir / file))
+
+        # === STACK EXPERTS ===
+        log("Stacking experts...")
+        new_tensors = {}
+
+        for (layer, linear_type, param), experts in sorted(grouped.items()):
+            expert_nums = sorted(experts.keys())
+
+            # Validate
+            if validate:
+                # Check continuity
+                expected = list(range(len(expert_nums)))
+                if expert_nums != expected:
+                    missing = set(expected) - set(expert_nums)
+                    if missing and not allow_missing_experts:
+                        raise ValueError(f"Missing experts {missing} in layer {layer}, {linear_type}.{param}")
+
+                # Check shapes and dtypes
+                shapes = [loaded[experts[e][1]][experts[e][0]].shape for e in expert_nums]
+                dtypes = [loaded[experts[e][1]][experts[e][0]].dtype for e in expert_nums]
+                if len(set(shapes)) > 1:
+                    raise ValueError(f"Shape mismatch in layer {layer}, {linear_type}.{param}: {set(shapes)}")
+                if len(set(dtypes)) > 1:
+                    raise ValueError(f"Dtype mismatch in layer {layer}, {linear_type}.{param}: {set(dtypes)}")
+
+            # Stack
+            tensors = [loaded[experts[e][1]][experts[e][0]] for e in expert_nums]
+            stacked = torch.stack(tensors, dim=0)
+            new_name = f"model.layers.{layer}.block_sparse_moe.{linear_type}.{param}"
+            new_tensors[new_name] = stacked
+            log(f" Layer {layer} {linear_type}.{param}: {list(stacked.shape)}")
+
+        # Copy other tensors
+        for name, file in other.items():
+            new_tensors[name] = loaded[file][name]
+
+        # === DISTRIBUTE ACROSS FILES ===
+        log("Distributing tensors...")
+        num_files = len(old_files)
+        tensor_sizes = [(n, t.numel() * t.element_size()) for n, t in new_tensors.items()]
+        tensor_sizes.sort(key=lambda x: x[1], reverse=True)
+
+        file_tensors = [{} for _ in range(num_files)]
+        file_sizes = [0] * num_files
+        new_weight_map = {}
+
+        for name, size in tensor_sizes:
+            min_idx = file_sizes.index(min(file_sizes))
+            file_tensors[min_idx][name] = new_tensors[name]
+            file_sizes[min_idx] += size
+            new_weight_map[name] = f"model-{min_idx+1:05d}-of-{num_files:05d}.safetensors"
+
+        # === SAVE FILES (TEMP) ===
+        log("Saving files...")
+        saved_files = []
+        for i, tensors in enumerate(file_tensors):
+            if tensors:
+                file_name = f"model-{i+1:05d}-of-{num_files:05d}.safetensors"
+                temp_name = f"{file_name}.tmp"
+                temp_path = source_dir / temp_name
+                save_file(tensors, str(temp_path))
+                temp_files.append(temp_path)
+                saved_files.append((temp_name, file_name))
+
+        # Save index (temp)
+        temp_index = source_dir / "model.safetensors.index.json.tmp"
+        with open(temp_index, "w") as f:
+            json.dump({"metadata": index_data.get("metadata", {}), "weight_map": new_weight_map}, f, indent=2)
+        temp_files.append(temp_index)
+
+        # === FINALIZE (DELETE OLD, RENAME TEMP) ===
+        log("Finalizing...")
+        # Delete old
+        for old in old_files:
+            (source_dir / old).unlink()
+        index_file.unlink()
+
+        # Rename temp
+        for temp, final in saved_files:
+            (source_dir / temp).rename(source_dir / final)
+        temp_index.rename(index_file)
+        temp_files.clear()
+
+        # === VERIFY ===
+        if validate:
+            with open(index_file) as f:
+                check = json.load(f)
+            remaining_experts = [n for n in check["weight_map"] if ".experts." in n]
+            if remaining_experts:
+                raise ValueError(f"Verification failed: {len(remaining_experts)} unpacked experts remain")
+
+        log(f"✓ Success! Transformed {len(grouped)} expert groups")
+
+    except Exception as e:
+        log(f"✗ Error: {e}")
+
+        # === ROLLBACK ===
+        if backup and backup_dir and backup_dir.exists():
+            log("Rolling back...")
+            for temp in temp_files:
+                if temp.exists(): temp.unlink()
+            for f in source_dir.glob("*.safetensors*"):
+                f.unlink()
+            for f in backup_dir.glob("*"):
+                shutil.copy2(f, source_dir / f.name)
+            log("✓ Rolled back to backup")
+
+        raise
+
+    finally:
+        # Cleanup temp files
+        for temp in temp_files:
+            if temp.exists(): temp.unlink()
 
+
 class GraniteMoeHybridParallelExpertsLinear(torch.nn.Linear):
     def __init__(self, num_experts: int, input_size: int, output_size: int) -> None:
         """Use a real Linear so that llmcompressor and vllm can handle it easier.
