diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py
index e31530109..91ae76594 100644
--- a/modelopt/torch/export/unified_export_megatron.py
+++ b/modelopt/torch/export/unified_export_megatron.py
@@ -280,6 +280,9 @@ def __init__(
                 "parallel_draft_heads_num_layers": mode_cfg["config"][
                     "eagle_architecture_config"
                 ]["parallel_draft_heads_num_layers"],
+                "use_embedding": mode_cfg["config"]["eagle_architecture_config"][
+                    "use_embedding"
+                ],
             }

             eagle_config_update = {
@@ -954,9 +957,11 @@ def _get_eagle_module_state_dict(self):
         self.rules["fc"](eagle_module.fc)

         if self.model.eagle_config.use_aux_hidden_state:
-            self.rules["enorm"](eagle_module.enorm)
+            if self.model.eagle_config.use_embedding:
+                self.rules["enorm"](eagle_module.enorm)
         elif self.model.eagle_config.use_mtp_layernorm:
-            self.rules["enorm"](eagle_module.enorm)
+            if self.model.eagle_config.use_embedding:
+                self.rules["enorm"](eagle_module.enorm)
             self.rules["hnorm"](eagle_module.hnorm)

         if self.model.eagle_config.use_last_layernorm:
diff --git a/modelopt/torch/speculative/eagle/default_config.py b/modelopt/torch/speculative/eagle/default_config.py
index 415b8373f..1edfe034b 100644
--- a/modelopt/torch/speculative/eagle/default_config.py
+++ b/modelopt/torch/speculative/eagle/default_config.py
@@ -47,6 +47,7 @@
     "use_mtp_layernorm": False,
     "parallel_draft_step": 1,
     "parallel_draft_heads_num_layers": 1,
-    "has_lm_head": False,
+    "has_lm_head": True,
+    "use_embedding": True,
     "head_dim": 128,
 }
diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py
index 651fca587..710407633 100644
--- a/modelopt/torch/speculative/plugins/megatron_eagle.py
+++ b/modelopt/torch/speculative/plugins/megatron_eagle.py
@@ -139,6 +139,7 @@ def dict_to_config(
         "parallel_draft_heads_num_layers"
     )
     config.has_lm_head = architecture_config.get("has_lm_head")
+    config.use_embedding = architecture_config.get("use_embedding")

     return config

@@ -365,17 +366,23 @@ def __init__(
         self._num_aux_hidden_states = len(self.config.eagle_aux_hidden_state_layer_ids)

         if self._num_aux_hidden_states > 0:
-            self.enorm = TENorm(config, config.hidden_size, config.layernorm_epsilon)
-            self._embeddings = None
+            if config.use_embedding:
+                self.enorm = TENorm(config, config.hidden_size, config.layernorm_epsilon)
+                self._embeddings = None
         elif self.config.use_mtp_layernorm:
-            self.enorm = TENorm(config, config.hidden_size, config.layernorm_epsilon)
+            if config.use_embedding:
+                self.enorm = TENorm(config, config.hidden_size, config.layernorm_epsilon)
             self.hnorm = TENorm(config, config.hidden_size, config.layernorm_epsilon)

         device = "cpu" if config.use_cpu_initialization else torch.cuda.current_device()

         # EAGLE-3 uses aux_hidden_states (usually >= 3); otherwise EAGLE-1
         fc_input_size_multiplier = (
-            self._num_aux_hidden_states if self._num_aux_hidden_states > 0 else 2
+            self._num_aux_hidden_states
+            if self._num_aux_hidden_states > 0
+            else 2
+            if config.use_embedding
+            else 1
         )

         # This linear was previously a ColumnParallelLinear. We changed it to a normal linear
@@ -408,27 +415,29 @@ def __init__(
             last_layer = self.decoder.layers[-1]
             last_layer.register_forward_hook(self._eagle3_layer_forward_hook)

-            # The first EAGLE3 layer needs to be specialized.
-            layer = self.decoder.layers[0]
-            self_attention = layer.self_attention
-            if not isinstance(self_attention, SelfAttention):
-                raise ValueError("EAGLE-3 only support SelfAttention (MHA, GQA).")
-
-            # EAGLE-3's first attention require [input_layernorm_output, aux_hidden_states]
-            self_attention.register_forward_pre_hook(self._eagle3_attention_forward_pre_hook)
-
-            # EAGLE-3's first layer reduces hidden_states from 2h to h.
-            self_attention.linear_qkv = tensor_parallel.ColumnParallelLinear(
-                self_attention.config.hidden_size * 2,
-                self_attention.query_projection_size + 2 * self_attention.kv_projection_size,
-                config=self_attention.config,
-                init_method=self_attention.config.init_method,
-                gather_output=False,
-                bias=self_attention.config.add_bias_linear or self_attention.config.add_qkv_bias,
-                skip_bias_add=False,
-                is_expert=False,
-                tp_comm_buffer_name="qkv",
-            )
+            if self.config.use_embedding:
+                # The first EAGLE3 layer needs to be specialized.
+                layer = self.decoder.layers[0]
+                self_attention = layer.self_attention
+                if not isinstance(self_attention, SelfAttention):
+                    raise ValueError("EAGLE-3 only support SelfAttention (MHA, GQA).")
+
+                # EAGLE-3's first attention require [input_layernorm_output, embeddings]
+                self_attention.register_forward_pre_hook(self._eagle3_attention_forward_pre_hook)
+
+                # EAGLE-3's first layer reduces hidden_states from 2h to h.
+                self_attention.linear_qkv = tensor_parallel.ColumnParallelLinear(
+                    self_attention.config.hidden_size * 2,
+                    self_attention.query_projection_size + 2 * self_attention.kv_projection_size,
+                    config=self_attention.config,
+                    init_method=self_attention.config.init_method,
+                    gather_output=False,
+                    bias=self_attention.config.add_bias_linear
+                    or self_attention.config.add_qkv_bias,
+                    skip_bias_add=False,
+                    is_expert=False,
+                    tp_comm_buffer_name="qkv",
+                )

         if self.config.draft_vocab_size != self.config.vocab_size:
             # Need an extra lm_head for eagle module since vocab size is reduced.
@@ -508,7 +517,7 @@ def _eagle3_attention_forward_pre_hook(self, module, input_layernorm_output):

     def forward(
         self,
-        embeddings: torch.Tensor,
+        embeddings: torch.Tensor | None,
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         rotary_pos_emb: torch.Tensor = None,
@@ -525,19 +534,27 @@ def forward(
         if self.config.sequence_parallel:
             rotary_seq_len *= self.config.tensor_model_parallel_size

+        if self.config.use_embedding:
+            assert embeddings is not None, "embeddings cannot be None when use_embedding is True"
+
         if self.config.use_mtp_layernorm:
-            embeddings = self.enorm(embeddings)
+            if self.config.use_embedding:
+                embeddings = self.enorm(embeddings)
             hidden_states = self.hnorm(hidden_states)

         # EAGLE-1 uses [s, b, h] input but EAGLE-3 uses [s, b, 2h] input
         if self._num_aux_hidden_states == 0:
-            # [s, b, 2h]
-            decoder_input = torch.cat((embeddings, hidden_states), dim=-1)
+            if self.config.use_embedding:
+                # [s, b, 2h]
+                decoder_input = torch.cat((embeddings, hidden_states), dim=-1)
+            else:
+                decoder_input = hidden_states
             decoder_input = self.fc(decoder_input)[0]
         else:
             # EAGLE-3 forward
             # EAGLE-3 uses self.fc outside eagle_module forward to convert hidden_states from [s, b, 3h]
-            self._embeddings = self.enorm(embeddings)
+            if self.config.use_embedding:
+                self._embeddings = self.enorm(embeddings)
             decoder_input = hidden_states

         if rotary_pos_emb is None:
@@ -708,7 +725,7 @@ def modify(
         if self.eagle_config.position_embedding_type not in ["rope", "yarn"]:
             raise ValueError("For EAGLE, only RoPE or YaRN embedding are supported")

-        if not self.pre_process and self.post_process:
+        if not self.pre_process and self.post_process and self.eagle_config.use_embedding:
             self.embedding = EagleLanguageModelEmbedding(
                 config=self.config,
                 vocab_size=self.vocab_size,
@@ -817,10 +834,11 @@ def _get_eagle_module_inputs(

         eagle_inputs["input_ids"] = padded_input_ids
         eagle_inputs["position_ids"] = position_ids
-        eagle_inputs["embedding"] = self.embedding(
-            input_ids=eagle_inputs["input_ids"],
-            position_ids=eagle_inputs["position_ids"],
-        )
+        if self.eagle_config.use_embedding:
+            eagle_inputs["embedding"] = self.embedding(
+                input_ids=eagle_inputs["input_ids"],
+                position_ids=eagle_inputs["position_ids"],
+            )
         eagle_inputs["hidden_states"] = hidden_states
         eagle_inputs["attention_mask"] = set_multi_step_attention_mask(attn_mask, ttt_step)

@@ -919,14 +937,13 @@ def _base_model_forward(
     def _eagle_forward(
         self,
         eagle_inputs,
-        output_weight,
         inference_params: InferenceParams = None,
         packed_seq_params: PackedSeqParams = None,
         inference_context: StaticInferenceContext | None = None,
         extra_block_kwargs: dict | None = None,
     ):
         eagle_hidden_states, eagle_hidden_states_pre_final_layernorm = self.eagle_module(
-            eagle_inputs["embedding"],
+            eagle_inputs["embedding"] if self.eagle_config.use_embedding else None,
             eagle_inputs["hidden_states"],
             eagle_inputs["attention_mask"],
             eagle_inputs["rotary_pos_emb"],
@@ -940,24 +957,22 @@ def _eagle_forward(
         if inference_context is not None:
             inference_context.sequence_len_offset += eagle_inputs["input_ids"].shape[1]

+        return eagle_hidden_states, eagle_hidden_states_pre_final_layernorm
+
+    def _get_eagle_logits(self, eagle_hidden_states, output_weight):
         if hasattr(self.eagle_module, "eagle_output_layer"):
             eagle_logits, _ = self.eagle_module.eagle_output_layer(eagle_hidden_states)
         else:
             eagle_logits, _ = self.output_layer(eagle_hidden_states, weight=output_weight)

+        draft_logits = [eagle_logits]
         if self.eagle_config.parallel_draft_step > 1:
             # Get additional draft logits from parallel draft heads
-            draft_logits_list = [eagle_logits]
             for draft_head in self.eagle_module.parallel_draft_heads:
-                draft_logits, _ = draft_head(eagle_hidden_states)
-                draft_logits_list.append(draft_logits)
-            eagle_logits = torch.cat(draft_logits_list, dim=0)
-
-        return (
-            eagle_hidden_states,
-            eagle_logits,
-            eagle_hidden_states_pre_final_layernorm,
-        )
+                parallel_logits, _ = draft_head(eagle_hidden_states)
+                draft_logits.append(parallel_logits)
+
+        return draft_logits

     def forward(
         self,
@@ -1074,15 +1089,16 @@ def forward(
                 ttt_step=ttt_step,
             )

-            _, eagle_logits, eagle_module_input_hidden_states = self._eagle_forward(
+            eagle_hidden_states, eagle_module_input_hidden_states = self._eagle_forward(
                 eagle_inputs,
-                output_weight,
                 inference_params=inference_params,
                 packed_seq_params=packed_seq_params,
                 inference_context=eagle_inference_context,
                 **(extra_block_kwargs or {}),
             )

+            eagle_logits = self._get_eagle_logits(eagle_hidden_states, output_weight)
+
             if self.config.sequence_parallel:
                 eagle_module_input_hidden_states = gather_from_sequence_parallel_region(
                     eagle_module_input_hidden_states
@@ -1112,7 +1128,7 @@ def forward(
                 return logits_sbh.transpose(0, 1).contiguous()

             for i in range(self.eagle_config.parallel_draft_step):
-                eagle_logit = eagle_logits[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
+                eagle_logit = eagle_logits[i]
                 if i > 0:
                     loss_ = self._compute_eagle_loss(
                         logits_sbh[i:], labels[:, i:], eagle_logit[:-i]
@@ -1129,9 +1145,7 @@ def forward(
             gathered_base_logits = gather_from_tensor_model_parallel_region(logits_sbh)
             base_top1 = gathered_base_logits.transpose(0, 1).argmax(dim=-1)
             for i in range(self.eagle_config.parallel_draft_step):
-                gathered_logits = gather_from_tensor_model_parallel_region(
-                    eagle_logits[i * input_ids.shape[1] : (i + 1) * input_ids.shape[1]]
-                )
+                gathered_logits = gather_from_tensor_model_parallel_region(eagle_logits[i])
                 gathered_logits = gathered_logits[ttt_step : -(1 + i)]
                 eagle_top1 = gathered_logits.transpose(0, 1).argmax(dim=-1)
                 if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
@@ -1378,6 +1392,7 @@ def pseudo_speculative_generate(
         hidden_states = hidden_states[:seq_len, :, :]

         draft_tokens = []
+        draft_hidden_states = []
         for _ in range(steps):
             padded_eagle_ids, seq_len, padded_hidden_states = right_padding(
                 eagle_ids, hidden_states
@@ -1389,51 +1404,58 @@ def pseudo_speculative_generate(
             )

             eagle_inputs = {}
-            eagle_inputs["input_ids"] = padded_eagle_ids
-            embeddings = self.embedding(
-                input_ids=padded_eagle_ids,
-                position_ids=eagle_position_ids,
-            )
-            eagle_inputs["embedding"] = embeddings
+            if self.eagle_config.use_embedding:
+                eagle_inputs["embedding"] = self.embedding(
+                    input_ids=padded_eagle_ids,
+                    position_ids=eagle_position_ids,
+                )
+
             eagle_inputs["hidden_states"] = padded_hidden_states
             eagle_inputs["attention_mask"] = eagle_attention_mask

             # [TODO] (chenhany): let the module compute itself
             eagle_inputs["rotary_pos_emb"] = None

-            _, eagle_logits, eagle_next_hidden_states_input = self._eagle_forward(
+            eagle_hidden_states, eagle_next_hidden_states_input = self._eagle_forward(
                 eagle_inputs,
-                output_weight,
             )

-            if self.eagle_config.parallel_draft_step > 1:
-                parallel_logits = [
-                    eagle_logits[
-                        padded_eagle_ids.shape[-1] * i + seq_len - 1 : padded_eagle_ids.shape[-1]
-                        * i
-                        + seq_len
+            if self.eagle_config.use_embedding:
+                eagle_logits = self._get_eagle_logits(eagle_hidden_states, output_weight)
+
+                if self.eagle_config.parallel_draft_step > 1:
+                    parallel_logits = [
+                        eagle_logits[i][seq_len - 1 : seq_len, :, :]
+                        for i in range(1, self.eagle_config.parallel_draft_step)
                     ]
-                    for i in range(1, self.eagle_config.parallel_draft_step)
-                ]
-                eagle_logits = eagle_logits[:seq_len, :, :]
+                eagle_logits = eagle_logits[0][:seq_len, :, :]

             if self.config.sequence_parallel:
+                eagle_hidden_states = gather_from_sequence_parallel_region(eagle_hidden_states)
                 eagle_next_hidden_states_input = gather_from_sequence_parallel_region(
                     eagle_next_hidden_states_input
                 )
+            eagle_hidden_states = eagle_hidden_states[:seq_len, :, :]
             eagle_next_hidden_states_input = eagle_next_hidden_states_input[:seq_len, :, :]

-            draft_token = (
-                gather_from_tensor_model_parallel_region(eagle_logits)[-1:, :, :]
-                .argmax(dim=-1)
-                .transpose(0, 1)
-            )
-            if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
-                draft_token += self.eagle_module.d2t[draft_token]
+            if self.eagle_config.use_embedding:
+                draft_token = (
+                    gather_from_tensor_model_parallel_region(eagle_logits)[-1:, :, :]
+                    .argmax(dim=-1)
+                    .transpose(0, 1)
+                )
+                if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
+                    draft_token += self.eagle_module.d2t[draft_token]

-            draft_tokens.append(draft_token)
+                draft_tokens.append(draft_token)
+
+                eagle_ids = torch.cat((eagle_ids, draft_token), dim=-1)
+            else:
+                draft_hidden_states.append(eagle_hidden_states[-1:, :, :])
+                eagle_ids = torch.cat(
+                    (eagle_ids, eagle_ids[:, -1:]), dim=-1
+                )  # dummy, only need this for attention mask

-            eagle_ids = torch.cat((eagle_ids, draft_token), dim=-1)
             hidden_states = torch.cat(
                 (
                     hidden_states,
@@ -1442,17 +1464,36 @@ def pseudo_speculative_generate(
                 dim=0,
             )

-        draft_tokens = torch.cat(draft_tokens, dim=-1)
-        if self.eagle_config.parallel_draft_step > 1:
-            parallel_logits = torch.cat(parallel_logits, dim=0)
-            parallel_tokens = (
-                (gather_from_tensor_model_parallel_region(parallel_logits))
+        if self.eagle_config.use_embedding:
+            draft_tokens = torch.cat(draft_tokens, dim=-1)
+            if self.eagle_config.parallel_draft_step > 1:
+                parallel_logits = torch.cat(parallel_logits, dim=0)
+                parallel_tokens = (
+                    (gather_from_tensor_model_parallel_region(parallel_logits))
+                    .argmax(dim=-1)
+                    .transpose(0, 1)
+                )
+                if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
+                    parallel_tokens += self.eagle_module.d2t[parallel_tokens]
+                draft_tokens = torch.cat((draft_tokens, parallel_tokens), dim=-1)
+        else:
+            eagle_hidden_states = torch.cat(draft_hidden_states, dim=0)
+            eagle_logits = self._get_eagle_logits(eagle_hidden_states, output_weight)
+            draft_logits = [eagle_logits[0]]
+            draft_logits.extend(
+                [
+                    eagle_logits[i][-1:, :, :]
+                    for i in range(1, self.eagle_config.parallel_draft_step)
+                ]
+            )
+            draft_logits = torch.cat((draft_logits), dim=0)
+            draft_tokens = (
+                gather_from_tensor_model_parallel_region(draft_logits)
                 .argmax(dim=-1)
                 .transpose(0, 1)
             )
             if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
-                parallel_tokens += self.eagle_module.d2t[parallel_tokens]
-            draft_tokens = torch.cat((draft_tokens, parallel_tokens), dim=-1)
+                draft_tokens += self.eagle_module.d2t[draft_tokens]

         return base_token, draft_tokens
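
Note (illustrative sketch, not part of the patch): the toy module below mirrors only the EAGLE-1 input path that this change makes conditional on use_embedding: the fc layer sees [embeddings; hidden_states] (width 2h) when token embeddings are used and hidden_states alone (width h) when they are skipped, matching the patched fc_input_size_multiplier expression. The class name ToyEagleFC and the tensor shapes are hypothetical; Megatron tensor/sequence parallelism, the EAGLE-3 aux-hidden-state path, and the new _get_eagle_logits refactor are deliberately left out.

# Hypothetical sketch of the use_embedding switch in the EAGLE-1 draft-input path.
import torch
import torch.nn as nn


class ToyEagleFC(nn.Module):
    def __init__(self, hidden_size: int, use_embedding: bool):
        super().__init__()
        self.use_embedding = use_embedding
        # Mirrors the patched expression: 2 when embeddings are concatenated, 1 otherwise.
        fc_input_size_multiplier = 2 if use_embedding else 1
        self.fc = nn.Linear(hidden_size * fc_input_size_multiplier, hidden_size, bias=False)

    def forward(self, embeddings: torch.Tensor | None, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.use_embedding:
            assert embeddings is not None, "embeddings cannot be None when use_embedding is True"
            decoder_input = torch.cat((embeddings, hidden_states), dim=-1)  # [s, b, 2h]
        else:
            decoder_input = hidden_states  # [s, b, h]
        return self.fc(decoder_input)


if __name__ == "__main__":
    s, b, h = 4, 2, 8
    hidden, emb = torch.randn(s, b, h), torch.randn(s, b, h)
    assert ToyEagleFC(h, use_embedding=True)(emb, hidden).shape == (s, b, h)
    assert ToyEagleFC(h, use_embedding=False)(None, hidden).shape == (s, b, h)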