
Commit 47b2fa0

Updated load_weights for Siglip2VisionTransformer
Signed-off-by: Oscar Gonzalez <ogonzal6@alumni.jh.edu>
1 parent 9e4d274 commit 47b2fa0

File tree

1 file changed: +27 additions, -133 deletions

vllm/model_executor/models/isaac.py

Lines changed: 27 additions & 133 deletions
@@ -18,9 +18,6 @@
 from transformers import PretrainedConfig, Qwen3Config
 from transformers.image_processing_utils import BatchFeature
 from transformers.tokenization_utils import TensorType
-from transformers.models.siglip2.modeling_siglip2 import (
-    Siglip2MLP,
-)
 from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig

 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -30,6 +27,7 @@
     AutoWeightsLoader,
     _merge_multimodal_embeddings,
     maybe_prefix,
+    init_vllm_registered_model,
 )
 from vllm.model_executor.models.qwen3 import Qwen3ForCausalLM
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -54,6 +52,15 @@
     SupportsPP,
 )

+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+)
+from vllm.model_executor.models.siglip2navit import Siglip2Encoder
+from vllm.attention.backends.registry import _Backend
+from vllm.model_executor.layers.quantization import QuantizationConfig
+
+from vllm.model_executor.layers.linear import ReplicatedLinear
+
 # ===== TensorStream Compatibility Layer for Isaac MRoPE =====
 # Minimal implementation of TensorStream classes needed for Isaac's 3D positional encoding

@@ -316,9 +323,10 @@ def __init__(self, config: PixelShuffleSiglip2VisionConfig):
         self.embed_dim = config.hidden_size
         self.patch_size = config.patch_size

-        self.patch_embedding = nn.Linear(
-            in_features=config.num_channels * self.patch_size * self.patch_size,
-            out_features=self.embed_dim,
+        self.patch_embedding = ReplicatedLinear(
+            input_size=config.num_channels * self.patch_size * self.patch_size,
+            output_size=self.embed_dim,
+            return_bias=False,
         )

         self.num_patches = config.num_patches
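
Note on the swap above: unlike nn.Linear, vLLM's linear layers return an (output, bias) tuple by default, which is why the diff passes return_bias=False so existing call sites that expect a plain tensor keep working. A minimal sketch of the new call convention (example sizes are illustrative; assumes a working vLLM install):

import torch
from vllm.model_executor.layers.linear import ReplicatedLinear

# Example sizes: 3 channels, 16x16 patches, hidden_size 768.
patch_embedding = ReplicatedLinear(
    input_size=3 * 16 * 16,
    output_size=768,
    return_bias=False,  # forward returns a single tensor, like nn.Linear
)
out = patch_embedding(torch.randn(4, 3 * 16 * 16))  # shape: (4, 768)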
@@ -1058,37 +1066,10 @@ def get_replacement_isaac(item_idx: int):
             )
         ]

-from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader,
-    maybe_remap_kv_scale_name,
-)
-from vllm.model_executor.models.utils import is_pp_missing_parameter
-from vllm.model_executor.models.siglip2navit import Siglip2VisionEmbeddings, Siglip2Encoder
-from vllm.attention.backends.registry import _Backend
-from vllm.model_executor.layers.quantization import QuantizationConfig
-
-class Siglip2VisionTransformer(nn.Module, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
-):
-
-    is_pooling_model = True
-
-    merge_by_field_config = True
-
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
-    }
-
+class Siglip2VisionTransformer(nn.Module):
     def __init__(
         self,
-        config,
+        config: PixelShuffleSiglip2VisionConfig,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
@@ -1151,64 +1132,28 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
             ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
         ]
-        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
+
         for name, loaded_weight in weights:
-            if "rotary_emb.inv_freq" in name:
-                continue
-            if self.quant_config is not None and (
-                scale_name := self.quant_config.get_cache_scale(name)
-            ):
-                # Loading kv cache quantization scales
-                param = params_dict[scale_name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                loaded_weight = (
-                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
-                )
-                weight_loader(param, loaded_weight)
-                loaded_params.add(scale_name)
-                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                if name.endswith("scale"):
-                    # Remapping the name of FP8 kv-scale.
-                    name = maybe_remap_kv_scale_name(name, params_dict)
-                    if name is None:
-                        continue
+
                 param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                if weight_loader == default_weight_loader:
-                    weight_loader(param, loaded_weight)
-                else:
-                    weight_loader(param, loaded_weight, shard_id)
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Remapping the name of FP8 kv-scale.
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-                if is_pp_missing_parameter(name, self):
-                    continue
-                print(f"qwen2: name={name}")
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params

+
 @MULTIMODAL_REGISTRY.register_processor(
     IsaacMultiModalProcessor,
     info=IsaacProcessingInfo,
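
The rewritten loader above keeps only what this vision tower needs: checkpoints store separate q_proj/k_proj/v_proj tensors while vLLM fuses them into a single qkv_proj parameter, so matching names are rewritten via stacked_params_mapping and handed to that parameter's weight_loader together with a shard id; the for/else falls through to default_weight_loader for everything else. A self-contained toy of the fused-QKV pattern (the toy_* names are illustrative, not vLLM API):

import torch

def toy_qkv_weight_loader(param: torch.Tensor, w: torch.Tensor, shard_id: str) -> None:
    # Copy one shard into its slice of the fused QKV weight
    # (assumes equal Q/K/V widths for simplicity).
    rows = w.shape[0]
    offset = {"q": 0, "k": rows, "v": 2 * rows}[shard_id]
    param.data[offset:offset + rows].copy_(w)

stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]
params_dict = {"attn.qkv_proj.weight": torch.empty(3 * 8, 8)}
checkpoint = [
    ("attn.q_proj.weight", torch.ones(8, 8)),
    ("attn.k_proj.weight", torch.full((8, 8), 2.0)),
    ("attn.v_proj.weight", torch.full((8, 8), 3.0)),
]
for name, loaded_weight in checkpoint:
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        name = name.replace(weight_name, param_name)
        toy_qkv_weight_loader(params_dict[name], loaded_weight, shard_id)
        break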
@@ -1217,6 +1162,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
 class IsaacForConditionalGeneration(
     Qwen3ForCausalLM, SupportsMultiModal, SupportsLoRA, SupportsPP, SupportsMRoPE
 ):
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -1230,7 +1176,7 @@ class IsaacForConditionalGeneration(
     }

     supports_encoder_tp_data = True
-
+
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
@@ -1261,14 +1207,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):

         # Initialize the parent class with updated config
         super().__init__(vllm_config=vllm_config, prefix=prefix)
-
+
         # Create the language model module to match checkpoint structure
         self.language_model = nn.ModuleDict({
             "embed_tokens": self.model.embed_tokens,
             "layers": self.model.layers,
             "norm": self.model.norm
         })
-
+
         config.vision_config.preserve_original_pe = True
         config.vision_config.use_rope = False
         config.vision_config.hidden_stride = config.vision_config.pixel_shuffle_scale_factor
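
The nn.ModuleDict in this hunk's context is an aliasing trick: it re-registers the already built decoder submodules under a language_model.* prefix so state-dict names line up with the checkpoint, without copying any weights. A tiny sketch of why this is cheap (toy modules, not the Isaac classes):

import torch.nn as nn

backbone = nn.Linear(4, 4)
alias = nn.ModuleDict({"embed_tokens": backbone})
# ModuleDict stores references, not copies, so both registrations
# resolve to the very same parameter tensor:
assert alias["embed_tokens"].weight is backbone.weight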
@@ -1431,61 +1377,9 @@ def get_input_embeddings(

         return inputs_embeds

-    def merge_qkv_weights(
-        weights: Iterable[tuple[str, torch.Tensor]]
-    ) -> Iterable[tuple[str, torch.Tensor]]:
-        """Merge separate Q, K, V projection weights into QKV format."""
-
-        # Buffer to collect q, k, v weights for each layer
-        qkv_buffer = {}
-
-        for name, tensor in weights:
-            # Check if this is a q/k/v projection weight
-            if '.q_proj.' in name or '.k_proj.' in name or '.v_proj.' in name:
-                # Extract the base name (everything before q/k/v_proj)
-                if '.q_proj.' in name:
-                    base_name = name.replace('.q_proj.', '.qkv_proj.')
-                    proj_type = 'q'
-                elif '.k_proj.' in name:
-                    base_name = name.replace('.k_proj.', '.qkv_proj.')
-                    proj_type = 'k'
-                else:  # v_proj
-                    base_name = name.replace('.v_proj.', '.qkv_proj.')
-                    proj_type = 'v'
-
-                # Store in buffer
-                if base_name not in qkv_buffer:
-                    qkv_buffer[base_name] = {}
-                qkv_buffer[base_name][proj_type] = tensor
-
-                # If we have all three (q, k, v), merge and yield
-                if len(qkv_buffer[base_name]) == 3:
-                    q = qkv_buffer[base_name]['q']
-                    k = qkv_buffer[base_name]['k']
-                    v = qkv_buffer[base_name]['v']
-
-                    # Concatenate along dim 0 for weight, dim agnostic for bias
-                    merged = torch.cat([q, k, v], dim=0)
-                    yield base_name, merged
-
-                    # Clear buffer
-                    del qkv_buffer[base_name]
-            else:
-                # Pass through non-qkv weights unchanged
-                yield name, tensor
-
-        # Check if any incomplete qkv sets remain
-        if qkv_buffer:
-            raise ValueError(f"Incomplete QKV weights found: {list(qkv_buffer.keys())}")
-
-
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         skip_prefixes = []
-        #if self.vision_embedding is None:
-        #    skip_prefixes.extend(["vision_embedding."])
-
-        # Usage:
-        #weights = self.merge_qkv_weights(weights)
+
         loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
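
With merge_qkv_weights gone, the top-level load_weights defers entirely to AutoWeightsLoader: hf_to_vllm_mapper rewrites checkpoint prefixes, and the Q/K/V concatenation the generator used to do by hand now happens inside each fused parameter's weight_loader (as in the vision-tower loader above). A toy of the prefix-rewriting step only (an illustrative reimplementation, not the vLLM WeightsMapper class; the prefixes are made up):

def remap_prefix(name: str, orig_to_new_prefix: dict[str, str]) -> str:
    # Rewrite the first matching checkpoint prefix to its vLLM name.
    for old, new in orig_to_new_prefix.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name

assert remap_prefix(
    "model.embed_tokens.weight",
    {"model.": "language_model."},
) == "language_model.embed_tokens.weight"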
