From 514e77fc3514d37545599336099e43d185447fb7 Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com> Date: Fri, 28 Nov 2025 14:00:32 +0800 Subject: [PATCH] Fix llama4_rope_with_position_map to support partial rotary factor --- .../frontend/nn/llm/position_embedding.py | 129 +++++++++++++----- 1 file changed, 93 insertions(+), 36 deletions(-) diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py index ee2a356299f1..e2a7801adda4 100644 --- a/python/tvm/relax/frontend/nn/llm/position_embedding.py +++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py @@ -117,7 +117,11 @@ def rope_freq_llama4( # pylint: disable=too-many-arguments,too-many-locals smoothed_freq_var = tir.Var("smoothed_freq", "float32") cos_freq = tir.cos(smoothed_freq_var).astype(dtype) sin_freq = tir.sin(smoothed_freq_var).astype(dtype) - return cos_freq, sin_freq, {smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq} + return ( + cos_freq, + sin_freq, + {smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq}, + ) def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals @@ -147,7 +151,11 @@ def rope_freq_llama3( # pylint: disable=too-many-arguments,too-many-locals smoothed_freq_var = tir.Var("smoothed_freq", "float32") cos_freq = tir.cos(smoothed_freq_var).astype(dtype) sin_freq = tir.sin(smoothed_freq_var).astype(dtype) - return cos_freq, sin_freq, {smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq} + return ( + cos_freq, + sin_freq, + {smoothed_freq_var: smoothed_freq, orig_freq_var: orig_freq}, + ) def rope_freq_longrope( # pylint: disable=too-many-arguments @@ -285,7 +293,7 @@ def switch_rope_freq_func(rope_scaling: Dict[str, Any]) -> Callable: beta_slow=rope_scaling["beta_slow"], inv_theta_log_scale=inv_theta_log_scale, ) - raise ValueError(f'Unsupported RoPE scaling type: {rope_scaling["rope_type"]}') + raise ValueError(f"Unsupported RoPE scaling type: {rope_scaling['rope_type']}") # mypy: disable-error-code="attr-defined" @@ -580,7 +588,10 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals # long factors is the first half, short factors is the second half long_factors = T.Buffer((rotary_dim // 2,), "float32", data=ext_factors.data) short_factors = T.Buffer( - (rotary_dim // 2,), "float32", data=ext_factors.data, elem_offset=(rotary_dim // 2) + (rotary_dim // 2,), + "float32", + data=ext_factors.data, + elem_offset=(rotary_dim // 2), ) if seq_len > original_max_position_embeddings: @@ -697,6 +708,10 @@ def llama4_rope_with_position_map( # pylint: disable=too-many-arguments rotary_dim = head_dim scale = tir.const(scale, "float32") is_longrope_scaling = rope_scaling.get("rope_type") == "longrope" + if is_longrope_scaling and "original_max_position_embeddings" in rope_scaling: + original_max_position_embeddings = rope_scaling["original_max_position_embeddings"] + else: + original_max_position_embeddings = 0 def _rope( # pylint: disable=too-many-arguments x: T.Buffer, @@ -780,7 +795,7 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals var_q: T.handle, var_k: T.handle, var_v: T.handle, - ext_factors: T.Buffer((rotary_dim // 2,), "float32"), # type: ignore + ext_factors: T.Buffer((rotary_dim,), "float32"), # type: ignore ): T.func_attr( { @@ -797,37 +812,79 @@ def fused_rope_longrope_scaling( # pylint: disable=too-many-locals position_map = T.match_buffer( var_position_map, (seq_len,), "int32", elem_offset=position_map_elem_offset ) - for iters in T.grid(seq_len, fused_heads, head_dim): - with T.sblock("llama_fused_rope"): - s, h, d = T.axis.remap("SSS", iters) - if h < num_q_heads: - q[s, h, d] = T.if_then_else( - d < rotary_dim, - _rope( - qkv, - s, - h, - d, - position_map[s], - ext_factors if is_longrope_scaling else None, - ), - qkv[s, h, d], - ) - elif h < num_q_heads + num_kv_heads: - k[s, h - num_q_heads, d] = T.if_then_else( - d < rotary_dim, - _rope( - qkv, - s, - h, - d, - position_map[s], - ext_factors if is_longrope_scaling else None, - ), - qkv[s, h, d], - ) - else: - v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d] + # long factors is the first half, short factors is the second half + long_factors = T.Buffer((rotary_dim // 2,), "float32", data=ext_factors.data) + short_factors = T.Buffer( + (rotary_dim // 2,), + "float32", + data=ext_factors.data, + elem_offset=(rotary_dim // 2), + ) + + if seq_len > original_max_position_embeddings: + for iters in T.grid(seq_len, fused_heads, head_dim): + with T.sblock("llama_fused_rope"): + s, h, d = T.axis.remap("SSS", iters) + if h < num_q_heads: + q[s, h, d] = T.if_then_else( + d < rotary_dim, + _rope( + qkv, + s, + h, + d, + position_map[s], + long_factors if is_longrope_scaling else None, + ), + qkv[s, h, d], + ) + elif h < num_q_heads + num_kv_heads: + k[s, h - num_q_heads, d] = T.if_then_else( + d < rotary_dim, + _rope( + qkv, + s, + h, + d, + position_map[s], + long_factors if is_longrope_scaling else None, + ), + qkv[s, h, d], + ) + else: + v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d] + else: + for iters in T.grid(seq_len, fused_heads, head_dim): + with T.sblock("llama_fused_rope"): + s, h, d = T.axis.remap("SSS", iters) + if h < num_q_heads: + q[s, h, d] = T.if_then_else( + d < rotary_dim, + _rope( + qkv, + s, + h, + d, + position_map[s], + short_factors if is_longrope_scaling else None, + ), + qkv[s, h, d], + ) + elif h < num_q_heads + num_kv_heads: + k[s, h - num_q_heads, d] = T.if_then_else( + d < rotary_dim, + _rope( + qkv, + s, + h, + d, + position_map[s], + short_factors if is_longrope_scaling else None, + ), + qkv[s, h, d], + ) + else: + v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d] if is_longrope_scaling: return fused_rope_longrope_scaling