diff --git a/vllm/model_executor/models/openpangu.py b/vllm/model_executor/models/openpangu.py index bf1b7570a882..d97c50b56ee3 100644 --- a/vllm/model_executor/models/openpangu.py +++ b/vllm/model_executor/models/openpangu.py @@ -86,6 +86,29 @@ def check_ffn_act_fn(act_fn: str): ) +def _normalize_rope_scaling_defaults( + rope_scaling_config: dict[str, Any] | None, max_position_embeddings: int +) -> dict[str, Any]: + """Return a DeepSeek-YaRN compatible rope_scaling dict.""" + rope_scaling = dict(rope_scaling_config) if rope_scaling_config is not None else {} + original_scaling_type = rope_scaling.get("type") + + rope_scaling.setdefault("beta_fast", 32) + rope_scaling.setdefault("beta_slow", 1) + rope_scaling.setdefault("factor", 1) + rope_scaling.setdefault("mscale", 1.0) + rope_scaling.setdefault("mscale_all_dim", 1.0) + rope_scaling.setdefault("type", "yarn") + rope_scaling.setdefault("original_max_position_embeddings", max_position_embeddings) + if "rope_type" not in rope_scaling: + if original_scaling_type is not None: + rope_scaling["rope_type"] = original_scaling_type + else: + rope_scaling["rope_type"] = "deepseek_yarn" + + return rope_scaling + + class OpenPanguMLP(nn.Module): def __init__( self, @@ -338,17 +361,10 @@ def __init__( prefix=f"{prefix}.o_proj", ) - # TODO: remove hard coding - rope_scaling = { - "beta_fast": 32, - "beta_slow": 1, - "factor": 1, - "mscale": 1.0, - "mscale_all_dim": 1.0, - "original_max_position_embeddings": max_position_embeddings, - "type": "yarn", - "rope_type": "deepseek_yarn", - } + rope_scaling_config = getattr(config, "rope_scaling", None) + rope_scaling = _normalize_rope_scaling_defaults( + rope_scaling_config, max_position_embeddings + ) self.rotary_emb = get_rope( qk_rope_head_dim, rotary_dim=qk_rope_head_dim, @@ -529,6 +545,10 @@ def _init_rotary_emb( if is_gguf and config.model_type == "PanguEmbedded": is_neox_style = False + rope_scaling = _normalize_rope_scaling_defaults( + rope_scaling, self.max_position_embeddings + ) + self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.head_dim,