@@ -14,6 +14,7 @@ class Embedding(torch.nn.Module):
     def __init__(self, embed, using_bf16: bool = False):
         super().__init__()
         self.bf16 = using_bf16
+        self.embed_dim = embed.weight.shape[-1]
         if using_bf16:
             # using bf16 embedding weight
             self.embed = embed.bfloat16()
@@ -24,7 +25,7 @@ def forward(self, input_ids):
         res = self.embed(input_ids)
         if self.bf16:
             res = res.float()
-        return res.view(-1, 1, 4096)
+        return res.view(-1, 1, self.embed_dim)
 
 class Lm(torch.nn.Module):
     def __init__(self, lm):
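The embedding width is now read from the weight tensor instead of being hardcoded to 4096, so the same wrapper serves both Qwen-7B (hidden size 4096) and Qwen-1.8B (hidden size 2048, consistent with the 16 heads × 128 head dim KV shape added later in this patch). A minimal standalone sketch of the idea, with illustrative sizes rather than values taken from the checkpoint:

import torch

embed = torch.nn.Embedding(num_embeddings=151936, embedding_dim=2048)
embed_dim = embed.weight.shape[-1]        # 2048, recovered from the weight itself
res = embed(torch.tensor([[1, 2, 3]]))
print(res.view(-1, 1, embed_dim).shape)   # torch.Size([3, 1, 2048])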
@@ -52,8 +53,9 @@ def __init__(self, args):
             self.sp_model = spm.SentencePieceProcessor(tokenizer_model)
         else:
             self.sp_model = None
-        self.load_model(args.path)
         self.max_length = 1024
+        self.hidden_size = 4096
+        self.load_model(args.path)
 
     def load_model(self, model_path: str):
         raise NotImplementedError
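Reordering __init__ so that load_model runs last is what makes the new default meaningful: the base class seeds hidden_size with 4096, and a subclass's load_model (the Qwen loader below overwrites it with transformer.embed_dim) can replace it before anything is exported. A small sketch of that override pattern, with hypothetical class names:

class BaseExporter:
    def __init__(self):
        self.max_length = 1024
        self.hidden_size = 4096   # default for the 7B-class models
        self.load_model()         # runs last so subclasses can override the default

    def load_model(self):
        raise NotImplementedError

class TinyExporter(BaseExporter):
    def load_model(self):
        self.hidden_size = 2048   # in the real script this comes from the checkpoint

print(TinyExporter().hidden_size)  # 2048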
@@ -131,7 +133,7 @@ def assert_equal(self, torch_outs, onnx_outs):
131133
132134 def export_lm (self ):
133135 model = self .lm
134- hidden_states = torch .randn (1 , 4096 )
136+ hidden_states = torch .randn (1 , self . hidden_size )
135137 onnx_model = f'./{ self .export_path } /lm.onnx'
136138 torch .onnx .export (model , (hidden_states ),
137139 onnx_model ,
@@ -177,7 +179,7 @@ def export_embed(self):
     def export_block(self, block_id: int):
         self.seq_len = 3
         self.token_len = 0
-        inputs_embeds = torch.randn((self.seq_len, 1, 4096))
+        inputs_embeds = torch.randn((self.seq_len, 1, self.hidden_size))
         attention_mask = self.get_attention_mask()
         position_ids = self.get_position_ids()
         past_key_values = torch.zeros(self.past_kv_shape[1:])
@@ -294,7 +296,7 @@ def forward(self, hidden_states, attention_mask, position_ids, past_kv):
                                              use_cache=True)
         if self.final_layernorm is not None:
             hidden_states = self.final_layernorm(hidden_states)
-            hidden_states = hidden_states.view(-1, 4096)[-1].view(1, 1, 4096)
+            hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size)
         if isinstance(presents, tuple):
             presents = torch.stack(presents)
         return hidden_states, presents
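The reshape in the final block keeps only the last token's hidden state, which is all the lm head needs for autoregressive decoding; the change just swaps the hardcoded width for self.hidden_size. A quick shape illustration with arbitrary sizes:

import torch

hidden_size = 2048
hidden_states = torch.randn(3, 1, hidden_size)   # (seq_len, 1, hidden), as in export_block
last = hidden_states.view(-1, hidden_size)[-1].view(1, 1, hidden_size)
print(last.shape)                                # torch.Size([1, 1, 2048])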
@@ -376,7 +378,7 @@ def forward(self, hidden_states, attention_mask, position_ids, past_kv):
                                              rotary_pos_emb=rotary_pos_emb)
         if self.final_layernorm is not None:
             hidden_states = self.final_layernorm(hidden_states)
-            hidden_states = hidden_states.view(-1, 4096)[-1].view(1, 1, 4096)
+            hidden_states = hidden_states.view(-1, self.hidden_size)[-1].view(1, 1, self.hidden_size)
         if isinstance(presents, tuple):
             presents = torch.stack(presents)
         return hidden_states, presents
@@ -442,27 +444,55 @@ def build_prompt(self, query):
442444
443445# qwen
444446class QWENBlock (torch .nn .Module ):
445- def __init__ (self , block , block_id , final_layernorm = None ):
447+ def __init__ (self , block , block_id , hidden_size , final_layernorm = None ):
446448 super ().__init__ ()
447449 self .block = block
448450 self .block_id = block_id
449451 self .final_layernorm = final_layernorm
452+ self .hidden_size = hidden_size
450453
451454 def forward (self , hidden_states , attention_mask , position_ids , past_kv ):
452455 theta = 1.0 / (10000.0 ** (torch .arange (0 , 128 , 2 , dtype = torch .float32 ) / 128 ))
453456 position_ids = position_ids .float ().reshape (- 1 , 1 )
454457 idx_theta = position_ids * theta
455458 rotary_pos_emb = torch .cat ((idx_theta , idx_theta ), dim = - 1 )
456459 rotary_pos_emb = rotary_pos_emb .unsqueeze (1 ).unsqueeze (0 )
457- hidden_states = hidden_states .view (1 , - 1 , 4096 )
460+ hidden_states = hidden_states .view (1 , - 1 , self . hidden_size )
458461 hidden_states , presents = self .block (hidden_states ,
459462 past_kv ,
460463 attention_mask ,
461464 rotary_pos_emb ,
462465 use_cache = True )
463466 if self .final_layernorm is not None :
464467 hidden_states = self .final_layernorm (hidden_states )
465- hidden_states = hidden_states .view (- 1 , 4096 )[- 1 ].view (1 , 1 , 4096 )
468+ hidden_states = hidden_states .view (- 1 , self .hidden_size )[- 1 ].view (1 , 1 , self .hidden_size )
469+ if isinstance (presents , tuple ):
470+ presents = torch .stack (presents )
471+ return hidden_states , presents
472+
473+ class QWEN18Block (torch .nn .Module ):
474+ def __init__ (self , block , block_id , hidden_size , final_layernorm = None ):
475+ super ().__init__ ()
476+ self .block = block
477+ self .block_id = block_id
478+ self .final_layernorm = final_layernorm
479+ self .hidden_size = hidden_size
480+
481+ def forward (self , hidden_states , attention_mask , position_ids , past_kv ):
482+ theta = 1.0 / (10000.0 ** (torch .arange (0 , 128 , 2 , dtype = torch .float32 ) / 128 ))
483+ position_ids = position_ids .float ().reshape (- 1 , 1 )
484+ idx_theta = position_ids * theta
485+ rotary_pos_emb = torch .cat ((idx_theta , idx_theta ), dim = - 1 ).unsqueeze (1 ).unsqueeze (0 )
486+ rotary_pos_emb = torch .stack ([torch .cos (rotary_pos_emb ), torch .sin (rotary_pos_emb )])
487+ hidden_states = hidden_states .view (1 , - 1 , self .hidden_size )
488+ hidden_states , presents = self .block (hidden_states ,
489+ rotary_pos_emb ,
490+ past_kv ,
491+ attention_mask ,
492+ use_cache = True )
493+ if self .final_layernorm is not None :
494+ hidden_states = self .final_layernorm (hidden_states )
495+ hidden_states = hidden_states .view (- 1 , self .hidden_size )[- 1 ].view (1 , 1 , self .hidden_size )
466496 if isinstance (presents , tuple ):
467497 presents = torch .stack (presents )
468498 return hidden_states , presents
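QWEN18Block differs from QWENBlock in two ways: the rotary term is passed as a stacked cos/sin tensor rather than raw angles, and the block is called with the rotary embedding before the KV cache and attention mask, presumably to match the argument order of the 1.8B model's layer forward. A standalone sketch of the two rotary layouts for a 3-token prompt:

import torch

position_ids = torch.arange(3, dtype=torch.float32).reshape(-1, 1)
theta = 1.0 / (10000.0 ** (torch.arange(0, 128, 2, dtype=torch.float32) / 128))
idx_theta = position_ids * theta

# QWENBlock: raw angles, shape (1, seq_len, 1, 128)
angles = torch.cat((idx_theta, idx_theta), dim=-1).unsqueeze(1).unsqueeze(0)
print(angles.shape)     # torch.Size([1, 3, 1, 128])

# QWEN18Block: cos and sin stacked on a new leading axis, shape (2, 1, seq_len, 1, 128)
cos_sin = torch.stack([torch.cos(angles), torch.sin(angles)])
print(cos_sin.shape)    # torch.Size([2, 1, 3, 1, 128])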
@@ -482,11 +512,18 @@ def load_model(self, model_path: str):
         # some wrapper
         self.stop_id = self.tokenizer.im_end_id
         self.block_nums = len(self.blocks_)
+        self.hidden_size = transformer.embed_dim
         self.embed = Embedding(self.embed_, self.embed_bf16)
         self.lm = Lm(self.lm_)
-        self.blocks = [QWENBlock(self.blocks_[i], i, self.final_layernorm_ if i == len(self.blocks_) - 1 else None) for i in range(self.block_nums)]
+        if self.block_nums == 32:
+            # qwen-7b
+            self.blocks = [QWENBlock(self.blocks_[i], i, self.hidden_size, self.final_layernorm_ if i == len(self.blocks_) - 1 else None) for i in range(self.block_nums)]
+            self.past_kv_shape = [32, 2, 1, 0, 32, 128]
+        elif self.block_nums == 24:
+            # qwen-1.8b
+            self.blocks = [QWEN18Block(self.blocks_[i], i, self.hidden_size, self.final_layernorm_ if i == len(self.blocks_) - 1 else None) for i in range(self.block_nums)]
+            self.past_kv_shape = [24, 2, 1, 0, 16, 128]
         # some config for export
-        self.past_kv_shape = [32, 2, 1, 0, 32, 128]
         self.block_dynamic_axes = {
             "inputs_embeds" : { 0: "seq_len" },
             "attention_mask" : { 2: "seq_len", 3: "seq_len" },
@@ -594,6 +631,7 @@ def get_position_ids(self) -> torch.Tensor:
     'chatglm3-6b': Chatglm3_6b,
     'codegeex2-6b': Chatglm2_6b,
     'Qwen-7B-Chat': Qwen_7b_Chat,
+    'Qwen-1_8B-Chat': Qwen_7b_Chat,
     'Baichuan2-7B-Chat': Llama2_7b_Chat,
     'Llama-2-7b-chat-ms': Llama2_7b_Chat
 }
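The new table entry maps 'Qwen-1_8B-Chat' onto the existing Qwen_7b_Chat exporter; the block-count check in load_model is what then routes it to QWEN18Block and the smaller KV-cache shape. A toy lookup sketch (model_dict and the stub classes here are stand-ins; the real table and exporter classes live elsewhere in the script):

class Qwen_7b_Chat: pass
class Llama2_7b_Chat: pass

model_dict = {
    'Qwen-7B-Chat': Qwen_7b_Chat,
    'Qwen-1_8B-Chat': Qwen_7b_Chat,   # new entry reuses the 7B exporter class
    'Llama-2-7b-chat-ms': Llama2_7b_Chat,
}
print(model_dict['Qwen-1_8B-Chat'].__name__)   # Qwen_7b_Chat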