22import argparse
33import os
44import hashlib
5+ import shutil
56from collections import OrderedDict
67
78parser = argparse .ArgumentParser (description = 'PyTorch ImageNet Validation' )
@@ -31,10 +32,9 @@ def main():
3132 if state_dict_key in checkpoint :
3233 state_dict = checkpoint [state_dict_key ]
3334 else :
34- print ("Error: No state_dict found in checkpoint {}." .format (args .checkpoint ))
35- exit (1 )
35+ state_dict = checkpoint
3636 else :
37- state_dict = checkpoint
37+ assert False
3838 for k , v in state_dict .items ():
3939 name = k [7 :] if k .startswith ('module' ) else k
4040 new_state_dict [name ] = v
@@ -43,7 +43,11 @@ def main():
4343 torch .save (new_state_dict , args .output )
4444 with open (args .output , 'rb' ) as f :
4545 sha_hash = hashlib .sha256 (f .read ()).hexdigest ()
46- print ("=> Saved state_dict to '{}, SHA256: {}'" .format (args .output , sha_hash ))
46+
47+ checkpoint_base = os .path .splitext (args .checkpoint )[0 ]
48+ final_filename = '-' .join ([checkpoint_base , sha_hash [:8 ]]) + '.pth'
49+ shutil .move (args .output , final_filename )
50+ print ("=> Saved state_dict to '{}, SHA256: {}'" .format (final_filename , sha_hash ))
4751 else :
4852 print ("Error: Checkpoint ({}) doesn't exist" .format (args .checkpoint ))
4953
0 commit comments