Commit 2fc3658

[feat] support gemma-2
1 parent 1fc1880 commit 2fc3658

File tree

1 file changed (+77, -30 lines)


llm_export.py

Lines changed: 77 additions & 30 deletions
@@ -82,6 +82,15 @@ def regist(self, model_type, model_map):
 
     def regist_models(self):
         self.defualt_map()
+        # regist models
+        self.regist_llama()
+        self.regist_qwen()
+        self.regist_glm()
+        self.regist_glm2()
+        self.regist_phi()
+        self.regist_gemma2()
+
+    def regist_llama(self):
         llama_map = self.default_map
         self.regist('llama', llama_map)
         self.regist('qwen2', llama_map)
@@ -92,10 +101,6 @@ def regist_models(self):
             'o_proj': 'o_proj'
         }
         self.regist('baichuan', baichuan_map)
-        self.regist_qwen()
-        self.regist_glm()
-        self.regist_glm2()
-        self.regist_phi()
 
     def regist_qwen(self):
         qwen_map = {
@@ -204,6 +209,20 @@ def regist_phi(self):
         }
         self.regist('phi-msft', phi_map)
 
+    def regist_gemma2(self):
+        gemma2_config = copy.deepcopy(self.default_config)
+        gemma2_config['head_dim'] = 'head_dim'
+        gemma2_decoder = copy.deepcopy(self.default_decoder)
+        gemma2_decoder['pre_feedforward_layernorm'] = 'pre_feedforward_layernorm'
+        gemma2_decoder['post_feedforward_layernorm'] = 'post_feedforward_layernorm'
+        gemma2_map = {
+            'config': gemma2_config,
+            'model': self.defualt_model,
+            'decoder': gemma2_decoder,
+            'attention': self.default_attention
+        }
+        self.regist('gemma2', gemma2_map)
+
     def defualt_map(self):
         # default map is `LlamaForCausalLM`
         self.config_key = 'config'
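
Note: the sketch below shows how the map registered above is expected to be consumed later in load_model. ModelMapper and get_map(config) both appear in this diff; the exact return convention is an assumption.

    # Minimal sketch, assuming get_map(config) returns (model_type, model_map)
    mapper = ModelMapper()
    model_type, model_map = mapper.get_map(config)  # config loaded from the HF checkpoint
    if model_type == 'gemma2':
        # the two extra norms registered above become part of the decoder mapping
        assert 'pre_feedforward_layernorm' in model_map['decoder']
        assert 'post_feedforward_layernorm' in model_map['decoder']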
@@ -356,10 +375,11 @@ def rebuild(self):
         return self.onnx_weight_path
 
 class MNNConveter:
-    def __init__(self, onnx_path, weight_ops, quant_bit = 4,quant_block = 0):
+    def __init__(self, onnx_path, weight_ops, config):
         self.weight_ops = weight_ops
-        self.quant_block = quant_block
-        self.quant_bit = quant_bit
+        self.quant_block = config.quant_block
+        self.quant_bit = config.quant_bit
+        self.lm_quant_bit = config.lm_quant_bit
         self.mnn_weight_offset = 0
         self.onnx_model_path = onnx_path
         self.mnn_model_path = onnx_path.replace('.onnx', '.mnn')
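
Note: MNNConveter now reads its quantization settings from the object passed as config (later in this diff the exporter passes itself). A minimal sketch of the interface it relies on, with hypothetical values:

    from dataclasses import dataclass

    @dataclass
    class QuantSettings:            # hypothetical stand-in for the exporter object
        quant_bit: int = 4          # bit width for ordinary linear layers
        quant_block: int = 128      # block size; 0 means per-channel
        lm_quant_bit: int = 8       # bit width for lm_head, may differ from quant_bit

    # converter = MNNConveter('llm.onnx', weight_ops, QuantSettings())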
@@ -458,49 +478,49 @@ def rebuild(self, json_path):
             json.dump(mnn_graph, file, ensure_ascii=False, indent=4)
         return self.mnn_weight_path
 
-    def quant(self, weight):
+    def quant(self, weight, quant_bit, quant_block):
         weight = weight.numpy()
         oc, ic = weight.shape
-        if self.quant_block == 0:
+        if quant_block == 0:
             block_size = ic
         else:
-            block_size = self.quant_block
+            block_size = quant_block
         block_num = ic // block_size
         weight = weight.reshape(oc, block_num, block_size)
         max_val = np.max(weight, axis=-1, keepdims=True)
         min_val = np.min(weight, axis=-1, keepdims=True)
-        offset = 1 << (self.quant_bit - 1)
+        offset = 1 << (quant_bit - 1)
         clip_max = offset - 1
         clip_min = -offset
         scale = (max_val - min_val) / (clip_max - clip_min)
         q_weight = np.round((weight - min_val) / scale).astype(np.int8) + clip_min
         q_weight = (np.clip(q_weight.flatten(), clip_min, clip_max) + offset).astype(np.uint8)
         q_weight = q_weight.reshape(-1, 2)
-        if self.quant_bit == 4:
+        if quant_bit == 4:
             q_weight = q_weight[:, 0] * 16 + q_weight[:, 1]
         alpha = np.stack([min_val.flatten(), scale.flatten()], axis=-1).flatten()
         return q_weight, alpha, clip_min
 
     def write_npy(self, data):
         return self.mnn_weight.write(data.tobytes())
 
-    def write_header(self, ic, oc):
+    def write_header(self, ic, oc, quant_bit):
         dim_num = self.mnn_weight.write(b'\x02')
         shape_dtype = np.int16
         if oc > 65535 or ic > 65535:
             shape_dtype = np.int32
         dim_length = self.write_npy(np.array([oc, ic]).astype(shape_dtype))
-        offset = 1 << (self.quant_bit - 1)
+        offset = 1 << (quant_bit - 1)
         weight_map = [i for i in range(-offset, offset)]
         weight_map.insert(0, len(weight_map))
         map_length = self.write_npy(np.array(weight_map, dtype=np.int8))
         header_length = dim_num + dim_length + map_length
         return header_length, shape_dtype == np.int32
 
-    def build_weight(self, linear):
+    def build_weight(self, linear, quant_bit, quant_block):
         ic, oc = linear.in_features, linear.out_features
-        q_weight, alpha, q_min = self.quant(linear.weight.data)
-        header_len, shape_int32 = self.write_header(ic, oc)
+        q_weight, alpha, q_min = self.quant(linear.weight.data, quant_bit, quant_block)
+        header_len, shape_int32 = self.write_header(ic, oc, quant_bit)
         weight_len = self.write_npy(q_weight) + header_len
         alpha_len = self.write_npy(alpha)
         if linear.bias is not None:
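
Note: a standalone paraphrase of the quantization math above, for readers skimming the diff. It mirrors quant(): per-block asymmetric min/max scaling, then (for 4-bit) two codes packed per byte; alpha stores (min, scale) per block. This is a sketch, not the exporter's API.

    import numpy as np

    def quant_sketch(weight, quant_bit=4, quant_block=128):
        # weight: 2-D numpy array of shape (out_channels, in_channels)
        oc, ic = weight.shape
        block_size = ic if quant_block == 0 else quant_block
        blocks = weight.reshape(oc, ic // block_size, block_size)
        max_val = blocks.max(axis=-1, keepdims=True)
        min_val = blocks.min(axis=-1, keepdims=True)
        offset = 1 << (quant_bit - 1)                  # 8 for 4-bit, 128 for 8-bit
        clip_max, clip_min = offset - 1, -offset
        scale = (max_val - min_val) / (clip_max - clip_min)
        q = np.round((blocks - min_val) / scale).astype(np.int8) + clip_min
        q = (np.clip(q.flatten(), clip_min, clip_max) + offset).astype(np.uint8)
        if quant_bit == 4:
            q = q.reshape(-1, 2)
            q = q[:, 0] * 16 + q[:, 1]                 # pack two 4-bit codes per byte
        alpha = np.stack([min_val.flatten(), scale.flatten()], axis=-1).flatten()
        return q, alpha, clip_min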
@@ -535,7 +555,10 @@ def rebuild_op(self, op, graph):
                 linear.out_features == oc and
                 (linear.bias is not None) == has_bias)
 
-        external, q_min, shape_int32 = self.build_weight(linear)
+
+        quant_bit = self.lm_quant_bit if 'lm_head' in name else self.quant_bit
+        print(f'quant layer {name}: bits {quant_bit}, block {self.quant_block}')
+        external, q_min, shape_int32 = self.build_weight(linear, quant_bit, self.quant_block)
 
         origin_input = op['inputIndexes']
         origin_output = op['outputIndexes']
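
Note: the per-layer bit width is chosen purely by op name, so lm_head can stay at higher precision than the other linears. A one-line restatement (default values are examples):

    def pick_quant_bit(name, quant_bit=4, lm_quant_bit=8):
        return lm_quant_bit if 'lm_head' in name else quant_bit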
@@ -630,9 +653,13 @@ def __init__(self, embed, config):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.embed = embed
+        if config.model_type == 'gemma2':
+            normalizer = torch.tensor(self.hidden_size**0.5)
+            self.embed.weight.data *= normalizer
 
     def forward(self, input_ids):
-        return self.embed(input_ids).view(-1, 1, self.hidden_size)
+        inputs_embeds = self.embed(input_ids).view(-1, 1, self.hidden_size)
+        return inputs_embeds
 
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
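
Note: Gemma-2 scales token embeddings by sqrt(hidden_size); folding that constant into the embedding weight once at export time, as the hunk above does, is equivalent to scaling every lookup at runtime. A small check with an illustrative hidden size:

    import copy
    import torch

    hidden_size = 2304                                  # illustrative Gemma-2-style size
    embed = torch.nn.Embedding(10, hidden_size)
    scaled = copy.deepcopy(embed)
    scaled.weight.data *= torch.tensor(hidden_size ** 0.5)

    ids = torch.tensor([1, 2, 3])
    assert torch.allclose(scaled(ids), embed(ids) * hidden_size ** 0.5)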
@@ -695,6 +722,7 @@ def forward(
         query_states = self.q_proj(hidden_states)
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
+        # print(f'hidden_states.shape = {hidden_states.shape}, query_states.shape = {query_states.shape}')
         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
@@ -734,7 +762,7 @@ def forward(
         attn_output = torch.matmul(attn_weights, value_states)
 
         attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = attn_output.reshape(bsz, q_len, -1)
         attn_output = self.o_proj(attn_output)
         return attn_output, past_key_value
 
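
Note: reshape(bsz, q_len, -1) is needed because for Gemma-2 num_heads * head_dim is not guaranteed to equal hidden_size; the o_proj input width is num_heads * head_dim. The numbers below are illustrative only, loosely based on a Gemma-2 config:

    import torch

    bsz, q_len = 1, 8
    num_heads, head_dim, hidden_size = 16, 256, 3584    # illustrative: 16 * 256 = 4096 != 3584
    attn_output = torch.zeros(bsz, num_heads, q_len, head_dim)
    attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, -1)
    assert attn_output.shape == (bsz, q_len, num_heads * head_dim)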
@@ -832,22 +860,32 @@ def forward(
             past_key_value=past_key_value,
         )
         # Fully Connected
-        if self.alpha != 1.0:
+        if not hasattr(self, 'post_attention_layernorm'):
+            # phi
+            feed_forward_hidden_states = self.mlp(norm_hidden_states)
+            hidden_states = hidden_states + feed_forward_hidden_states + residual
+        elif self.alpha != 1.0:
             # chatglm-6b
             hidden_states = norm_hidden_states * self.alpha + hidden_states
             mlp_input = self.post_attention_layernorm(hidden_states)
             mlp_output = self.mlp(mlp_input)
             hidden_states = mlp_input * self.alpha + mlp_output
-        elif hasattr(self, 'post_attention_layernorm'):
+        elif hasattr(self, 'pre_feedforward_layernorm'):
+            # gemma2
+            hidden_states = self.post_attention_layernorm(hidden_states)
             hidden_states = residual + hidden_states
             residual = hidden_states
-            hidden_states = self.post_attention_layernorm(hidden_states)
+            hidden_states = self.pre_feedforward_layernorm(hidden_states)
             hidden_states = self.mlp(hidden_states)
+            hidden_states = self.post_feedforward_layernorm(hidden_states)
             hidden_states = residual + hidden_states
         else:
-            # phi
-            feed_forward_hidden_states = self.mlp(norm_hidden_states)
-            hidden_states = hidden_states + feed_forward_hidden_states + residual
+            # general
+            hidden_states = residual + hidden_states
+            residual = hidden_states
+            hidden_states = self.post_attention_layernorm(hidden_states)
+            hidden_states = self.mlp(hidden_states)
+            hidden_states = residual + hidden_states
 
         return hidden_states, present_key_value
 
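
Note: written out linearly, the new gemma2 branch normalizes the attention output before adding the residual, then wraps the MLP between the pre/post feed-forward norms, matching the elif above. Sketch only (block stands for the decoder layer; hidden_states is the attention output, residual the pre-attention input):

    def gemma2_mlp_path(block, hidden_states, residual):
        hidden_states = block.post_attention_layernorm(hidden_states)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states = block.pre_feedforward_layernorm(hidden_states)
        hidden_states = block.mlp(hidden_states)
        hidden_states = block.post_feedforward_layernorm(hidden_states)
        return residual + hidden_states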
@@ -886,6 +924,10 @@ def init_from_args(self, args):
         self.skip_slim = args.skip_slim
         self.quant_bit = args.quant_bit
         self.quant_block = args.quant_block
+        if args.lm_quant_bit is not None:
+            self.lm_quant_bit = args.lm_quant_bit
+        else:
+            self.lm_quant_bit = self.quant_bit
         # init export dst dir
         if not os.path.exists(self.dst_path):
             os.makedirs(self.dst_path)
@@ -920,6 +962,7 @@ def load_model(self, model_path):
                 self.stop_ids.append(id)
         self.stop_ids = [stop_id for stop_id in self.stop_ids if stop_id is not None]
         model_mapper = ModelMapper()
+
         self.model_type, self.model_map = model_mapper.get_map(self.config)
         # print(self.model)
         # print(self.model_type, self.model_map)
@@ -929,7 +972,8 @@ def load_model(self, model_path):
             self.num_key_value_heads = self.num_attention_heads
         if not hasattr(self, 'rope_theta') or self.rope_theta is None:
             self.rope_theta = 10000.0
-        self.head_dim = self.hidden_size // self.num_attention_heads
+        if not hasattr(self, 'head_dim') or self.head_dim is None:
+            self.head_dim = self.hidden_size // self.num_attention_heads
         # some export info
         self.past_kv_shape = [self.num_hidden_layers, 2, 1, 0, self.num_key_value_heads, self.head_dim]
         self.block_dynamic_axes = {
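
Note: the fallback now only applies when the checkpoint config does not declare head_dim; Gemma-2 configs declare it explicitly, and it need not equal hidden_size // num_attention_heads. A minimal restatement with illustrative numbers:

    config = {'hidden_size': 3584, 'num_attention_heads': 16, 'head_dim': 256}  # illustrative
    head_dim = config.get('head_dim') or config['hidden_size'] // config['num_attention_heads']
    assert head_dim == 256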
@@ -1082,6 +1126,8 @@ def build_prompt(self, query):
             return f'{query}[gMASK]<sop>'
         if 'phi-2' in self.model_name:
             return f'Instruct: {query}\nOutput:'
+        if 'gemma-2' in self.model_name:
+            return f'<bos><start_of_turn>user\n{query}<end_of_turn>\n<start_of_turn>model\n'
         return query
 
     def str_to_ids(self, prompt):
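
Note: for reference, the string the new gemma-2 branch produces (the query value is illustrative):

    query = 'hello'
    prompt = f'<bos><start_of_turn>user\n{query}<end_of_turn>\n<start_of_turn>model\n'
    # -> '<bos><start_of_turn>user\nhello<end_of_turn>\n<start_of_turn>model\n'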
@@ -1246,7 +1292,7 @@ def export(self, export_type):
                 self.onnx_slim(onnx_model)
             if export_mnn:
                 # convert onnx to mnn and quant weight
-                MNNConveter(onnx_model, self.unloaded_ops, self.quant_bit, self.quant_block).export()
+                MNNConveter(onnx_model, self.unloaded_ops, self).export()
             else:
                 # export weight to llm.onnx.data
                 self.onnx_load_param(onnx_model)
@@ -1479,7 +1525,7 @@ def export(self, export_type):
         if not self.skip_slim:
             self.onnx_slim(onnx_model)
         if 'mnn' in export_type:
-            MNNConveter(onnx_model, None, self.quant_bit, self.quant_block).export()
+            MNNConveter(onnx_model, None, self).export()
 
     def build_prompt(self, query):
         return f'[CLS]{query}[SEP]'
@@ -1506,7 +1552,8 @@ def get_attention_mask(self) -> torch.Tensor:
     parser.add_argument('--export', type=str, default=None, help='export model to an onnx/mnn model.')
     parser.add_argument('--skip_slim', action='store_true', help='Whether or not to skip onnx-slim.')
     parser.add_argument('--quant_bit', type=int, default=4, help='mnn quant bit, 4 or 8, default is 4.')
-    parser.add_argument('--quant_block', type=int, default=0, help='mnn quant block, default is 0 mean channle-wise.')
+    parser.add_argument('--quant_block', type=int, default=128, help='mnn quant block, default is 128; 0 means channel-wise.')
+    parser.add_argument('--lm_quant_bit', type=int, default=None, help='mnn lm_head quant bit, 4 or 8, default is `quant_bit`.')
 
     args = parser.parse_args()
     model_path = args.path
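
Note: a hypothetical invocation exercising the new flags. The model path and the --export value are placeholders; only flags visible in this diff (--path, --export, --quant_bit, --quant_block, --lm_quant_bit) are assumed to exist.

    python llm_export.py --path ./gemma-2-2b-it --export mnn --quant_bit 4 --quant_block 128 --lm_quant_bit 8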
