@@ -990,6 +990,28 @@ def get_position_ids(self) -> torch.Tensor:
     def get_attention_mask(self) -> torch.Tensor:
         return torch.ones([1, 1, 1, self.seq_len], dtype=torch.long)
 
+class LoraModule(torch.nn.Module):
+    def __init__(self, args):
+        super().__init__()
+        self.onnx_path = args.onnx_path
+        self.mnn_path = args.mnn_path
+        self.export_mnn = args.export_mnn
+        import peft
+        lora_weight = peft.load_peft_weights(args.path)
+        for k, v in lora_weight.items():
+            k = k.replace('.', '/')
+            self.register_buffer(k, v.cpu())
+
+    def forward(self, dummy):
+        return self._buffers
+
+    def export(self):
+        onnx_model = f'./{self.onnx_path}/lora.onnx'
+        torch.onnx.export(self.eval(), torch.tensor([]), onnx_model)
+        if self.export_mnn:
+            onnx2mnn(onnx_model, self.mnn_path)
+
+
 if __name__ == '__main__':
     llm_models = {
         'chatglm-6b': Chatglm_6b,
@@ -1006,7 +1028,8 @@ def get_attention_mask(self) -> torch.Tensor:
         'Yi-6B-Chat': Llama2_7b_Chat,
         'deepseek-llm-7b-chat': Llama2_7b_Chat,
         'phi-2': phi_2,
-        'bge-large-zh': bge
+        'bge-large-zh': bge,
+        'lora': LoraModule
     }
     parser = argparse.ArgumentParser(description='llm_exporter', formatter_class=argparse.RawTextHelpFormatter)
     parser.add_argument('--path', type=str, default='THUDM/chatglm-6b', required=True,
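A side note on what the new class exports: forward() just returns the module's registered buffer dict, so the traced lora.onnx is a graph whose only job is to carry the LoRA tensors loaded from args.path. Below is a minimal inspection sketch (not part of this commit), assuming the onnx package is installed and the file landed at ./onnx/lora.onnx (the real location depends on --onnx_path); since torch.onnx.export flattens a dict return value, whether the registered buffer names survive as output names is worth checking rather than assuming.

import onnx

# Hypothetical path; LoraModule.export() writes f'./{onnx_path}/lora.onnx'.
model = onnx.load('./onnx/lora.onnx')

# The LoRA tensors show up as graph outputs and/or the initializers feeding them.
print('outputs     :', [o.name for o in model.graph.output])
print('initializers:', [i.name for i in model.graph.initializer])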