@@ -253,7 +253,189 @@ def replace_moe_modules(module, name=''):
     return model


+def pack_3d_experts(
+    source_dir,
+    validate=True,
+    backup=True,
+    allow_missing_experts=False,
+    verbose=True,
+):
263+ """
264+ Transform MoE model from per-expert storage to 3D stacked tensors.
265+
266+ From: model.layers.{L}.block_sparse_moe.{linear_type}.experts.{E}.{param}
267+ To: model.layers.{L}.block_sparse_moe.{linear_type}.{param}
268+
269+ Args:
270+ source_dir: Model directory path
271+ validate: Validate shapes and expert continuity
272+ backup: Create backup before modification (RECOMMENDED)
273+ allow_missing_experts: Don't fail if some experts are missing
274+ verbose: Print progress messages
275+ """
+    source_dir = Path(source_dir)
+    index_file = source_dir / "model.safetensors.index.json"
+    backup_dir = None
+    temp_files = []
+
+    def log(msg):
+        if verbose:
+            print(msg)
+
+    try:
+        # === BACKUP ===
+        if backup:
+            backup_dir = source_dir.parent / f"{source_dir.name}.backup.{datetime.now().strftime('%Y%m%d_%H%M%S')}"
+            backup_dir.mkdir(parents=True)
+            for f in source_dir.glob("*.safetensors*"):
+                shutil.copy2(f, backup_dir / f.name)
+            log(f"✓ Backup created at {backup_dir}")
+
+        # === LOAD INDEX ===
+        with open(index_file) as f:
+            index_data = json.load(f)
+        weight_map = index_data["weight_map"]
+
+        # === GROUP TENSORS ===
+        # {(layer, linear_type, param): {expert_num: (tensor_name, file)}}
+        grouped = defaultdict(dict)
+        other = {}
+
+        for name, file in weight_map.items():
+            if ".block_sparse_moe." in name and ".experts." in name:
+                parts = name.split(".")
+                try:
+                    layer = int(parts[parts.index("layers") + 1])
+                    expert = int(parts[parts.index("experts") + 1])
+                    linear_type = parts[parts.index("experts") - 1]
+                    param = ".".join(parts[parts.index("experts") + 2:])
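+                    # e.g. "model.layers.3.block_sparse_moe.input_linear.experts.7.weight"
+                    # parses to layer=3, linear_type="input_linear", expert=7,
+                    # param="weight" (illustrative name; the actual linear_type
+                    # values depend on the checkpoint).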
+                    grouped[(layer, linear_type, param)][expert] = (name, file)
+                except (ValueError, IndexError):
+                    other[name] = file
+            else:
+                other[name] = file
+
+        log(f"✓ Found {len(grouped)} expert groups, {len(other)} other tensors")
+
+        # === LOAD FILES ===
+        log("Loading files...")
+        loaded = {}
+        old_files = set(weight_map.values())
+        for file in old_files:
+            loaded[file] = load_file(str(source_dir / file))
+
+        # === STACK EXPERTS ===
+        log("Stacking experts...")
+        new_tensors = {}
+
+        for (layer, linear_type, param), experts in sorted(grouped.items()):
+            expert_nums = sorted(experts.keys())
+
+            # Validate
+            if validate:
+                # Check continuity
+                expected = list(range(len(expert_nums)))
+                if expert_nums != expected:
+                    missing = set(expected) - set(expert_nums)
+                    if missing and not allow_missing_experts:
+                        raise ValueError(f"Missing experts {missing} in layer {layer}, {linear_type}.{param}")
+
+                # Check shapes and dtypes
+                shapes = [loaded[experts[e][1]][experts[e][0]].shape for e in expert_nums]
+                dtypes = [loaded[experts[e][1]][experts[e][0]].dtype for e in expert_nums]
+                if len(set(shapes)) > 1:
+                    raise ValueError(f"Shape mismatch in layer {layer}, {linear_type}.{param}: {set(shapes)}")
+                if len(set(dtypes)) > 1:
+                    raise ValueError(f"Dtype mismatch in layer {layer}, {linear_type}.{param}: {set(dtypes)}")
+
+            # Stack
+            tensors = [loaded[experts[e][1]][experts[e][0]] for e in expert_nums]
+            stacked = torch.stack(tensors, dim=0)
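+            # torch.stack adds a leading expert dim: E tensors of shape
+            # (out, in) become one (E, out, in) tensor.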
+            new_name = f"model.layers.{layer}.block_sparse_moe.{linear_type}.{param}"
+            new_tensors[new_name] = stacked
+            log(f"  Layer {layer} {linear_type}.{param}: {list(stacked.shape)}")
+
+        # Copy other tensors
+        for name, file in other.items():
+            new_tensors[name] = loaded[file][name]
+
+        # === DISTRIBUTE ACROSS FILES ===
+        log("Distributing tensors...")
+        num_files = len(old_files)
+        tensor_sizes = [(n, t.numel() * t.element_size()) for n, t in new_tensors.items()]
+        tensor_sizes.sort(key=lambda x: x[1], reverse=True)
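+        # Largest-first greedy placement (LPT heuristic): dropping each
+        # tensor into the currently smallest shard keeps shard sizes
+        # roughly balanced without exact bin packing.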
+
+        file_tensors = [{} for _ in range(num_files)]
+        file_sizes = [0] * num_files
+        new_weight_map = {}
+
+        for name, size in tensor_sizes:
+            min_idx = file_sizes.index(min(file_sizes))
+            file_tensors[min_idx][name] = new_tensors[name]
+            file_sizes[min_idx] += size
+            new_weight_map[name] = f"model-{min_idx + 1:05d}-of-{num_files:05d}.safetensors"
+
+        # === SAVE FILES (TEMP) ===
+        log("Saving files...")
+        saved_files = []
+        for i, tensors in enumerate(file_tensors):
+            if tensors:
+                file_name = f"model-{i + 1:05d}-of-{num_files:05d}.safetensors"
+                temp_name = f"{file_name}.tmp"
+                temp_path = source_dir / temp_name
+                save_file(tensors, str(temp_path))
+                temp_files.append(temp_path)
+                saved_files.append((temp_name, file_name))
+
+        # Save index (temp)
+        temp_index = source_dir / "model.safetensors.index.json.tmp"
+        with open(temp_index, "w") as f:
+            json.dump({"metadata": index_data.get("metadata", {}), "weight_map": new_weight_map}, f, indent=2)
+        temp_files.append(temp_index)
+
+        # === FINALIZE (DELETE OLD, RENAME TEMP) ===
+        log("Finalizing...")
+        # Delete old
+        for old in old_files:
+            (source_dir / old).unlink()
+        index_file.unlink()
+
+        # Rename temp
+        for temp, final in saved_files:
+            (source_dir / temp).rename(source_dir / final)
+        temp_index.rename(index_file)
+        temp_files.clear()
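+        # The delete-then-rename window above is not atomic; if the process
+        # dies mid-way, the backup created earlier is the recovery path.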
+
+        # === VERIFY ===
+        if validate:
+            with open(index_file) as f:
+                check = json.load(f)
+            remaining_experts = [n for n in check["weight_map"] if ".experts." in n]
+            if remaining_experts:
+                raise ValueError(f"Verification failed: {len(remaining_experts)} unpacked experts remain")
+
+        log(f"✓ Success! Transformed {len(grouped)} expert groups")
+
+    except Exception as e:
+        log(f"✗ Error: {e}")
+
+        # === ROLLBACK ===
+        if backup and backup_dir and backup_dir.exists():
+            log("Rolling back...")
+            for temp in temp_files:
+                if temp.exists():
+                    temp.unlink()
+            for f in source_dir.glob("*.safetensors*"):
+                f.unlink()
+            for f in backup_dir.glob("*"):
+                shutil.copy2(f, source_dir / f.name)
+            log("✓ Rolled back to backup")
+
+        raise
+
+    finally:
+        # Cleanup temp files
+        for temp in temp_files:
+            if temp.exists():
+                temp.unlink()
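+
+# Example usage (hypothetical path; run once on a per-expert checkpoint):
+#   pack_3d_experts("checkpoints/granite-moe-per-expert", backup=True)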

+
 class GraniteMoeHybridParallelExpertsLinear(torch.nn.Linear):
     def __init__(self, num_experts: int, input_size: int, output_size: int) -> None:
         """Use a real Linear so that llmcompressor and vllm can handle it easier.