Commit 185fcab

support lora weight export.

1 parent a667117

File tree: 3 files changed, +26 −3 lines changed

README.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -8,7 +8,7 @@ llm-export is an llm model export tool that can export llm models to onnx and mnn formats
 - 🚀 Optimized the original code to support dynamic shapes
 - 🚀 Optimized the original code to reduce the constant portion
 - 🚀 Optimized the onnx model with [OnnxSlim](https://github.com/WeLoveAI/OnnxSlim), ~5% performance gain; by [@inisis](https://github.com/inisis)
-
+- 🚀 Support exporting lora weights to onnx and mnn

 ## Model Support and Downloads
 - [![Download][download-chatglm-6b-onnx]][release-chatglm-6b-onnx]
```

README_en.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -7,7 +7,7 @@ llm-export is a tool for exporting llm models, capable of converting llm models
 - 🚀 Optimized the original code to support dynamic shapes
 - 🚀 Optimized the original code to reduce the constant portion
 - 🚀 Slimmed the onnx model with [OnnxSlim](https://github.com/WeLoveAI/OnnxSlim), ~5% speedup; by [@inisis](https://github.com/inisis)
-
+- 🚀 Support exporting lora weights to onnx or MNN models

 ## Model Support and Downloads
```
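For orientation, a minimal sketch of how the new export path can be driven once the `LoraModule` class (added in llm_export.py below) is in scope; the attribute names mirror what the class reads, while the values and directory layout are assumptions:

```python
# Hypothetical driver for the LoraModule added in llm_export.py below.
# Only the attribute names (path, onnx_path, mnn_path, export_mnn) come
# from the class body; the values and directory layout are assumptions.
import argparse

args = argparse.Namespace(
    path='./my-lora-adapter',  # peft adapter directory (assumed layout)
    onnx_path='onnx',          # directory that will receive lora.onnx
    mnn_path='mnn',            # target path for the converted MNN model
    export_mnn=False,          # set True to also run onnx2mnn
)
LoraModule(args).export()      # writes ./onnx/lora.onnx
```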

llm_export.py

Lines changed: 24 additions & 1 deletion

```diff
@@ -990,6 +990,28 @@ def get_position_ids(self) -> torch.Tensor:
     def get_attention_mask(self) -> torch.Tensor:
         return torch.ones([1, 1, 1, self.seq_len], dtype=torch.long)

+class LoraModule(torch.nn.Module):
+    def __init__(self, args):
+        super().__init__()
+        self.onnx_path = args.onnx_path
+        self.mnn_path = args.mnn_path
+        self.export_mnn = args.export_mnn
+        import peft
+        lora_weight = peft.load_peft_weights(args.path)
+        for k, v in lora_weight.items():
+            k = k.replace('.', '/')
+            self.register_buffer(k, v.cpu())
+
+    def forward(self, dummy):
+        return self._buffers
+
+    def export(self):
+        onnx_model = f'./{self.onnx_path}/lora.onnx'
+        torch.onnx.export(self.eval(), torch.tensor([]), onnx_model)
+        if self.export_mnn:
+            onnx2mnn(onnx_model, self.mnn_path)
+
+
 if __name__ == '__main__':
     llm_models = {
         'chatglm-6b': Chatglm_6b,
@@ -1006,7 +1028,8 @@ def get_attention_mask(self) -> torch.Tensor:
         'Yi-6B-Chat': Llama2_7b_Chat,
         'deepseek-llm-7b-chat': Llama2_7b_Chat,
         'phi-2': phi_2,
-        'bge-large-zh': bge
+        'bge-large-zh': bge,
+        'lora': LoraModule
     }
     parser = argparse.ArgumentParser(description='llm_exporter', formatter_class=argparse.RawTextHelpFormatter)
     parser.add_argument('--path', type=str, default='THUDM/chatglm-6b', required=True,
```

Note that `register_buffer` rejects names containing `.`, which is why the dotted peft keys are rewritten with `/` before registration, and the `dummy` argument to `forward` exists only because `torch.onnx.export` traces the module with one input.
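A small consumer sketch, not part of the commit: since `forward` ignores its input and just returns the buffer dict, every lora tensor should surface as a graph output in lora.onnx. Assuming the file was written to ./onnx/ as above, onnxruntime can enumerate the outputs generically:

```python
# Hypothetical round-trip check with onnxruntime (an assumption, not
# part of this commit); expects ./onnx/lora.onnx from the export above.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession('./onnx/lora.onnx',
                            providers=['CPUExecutionProvider'])

# The unused dummy input may be pruned during export, so feed an empty
# array only for inputs the session still reports.
feeds = {i.name: np.zeros(0, dtype=np.float32) for i in sess.get_inputs()}

# Each registered buffer comes back as one output tensor.
for meta, value in zip(sess.get_outputs(), sess.run(None, feeds)):
    print(meta.name, value.shape)
```

If the exporter preserves the buffer names, replacing `/` with `.` in each output name recovers the original peft keys.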
