1818from transformers import PretrainedConfig , Qwen3Config
1919from transformers .image_processing_utils import BatchFeature
2020from transformers .tokenization_utils import TensorType
21- from transformers .models .siglip2 .modeling_siglip2 import (
22- Siglip2MLP ,
23- )
2421from transformers .models .siglip2 .configuration_siglip2 import Siglip2VisionConfig
2522
2623from vllm .multimodal import MULTIMODAL_REGISTRY
3027 AutoWeightsLoader ,
3128 _merge_multimodal_embeddings ,
3229 maybe_prefix ,
30+ init_vllm_registered_model ,
3331)
3432from vllm .model_executor .models .qwen3 import Qwen3ForCausalLM
3533from vllm .model_executor .models .module_mapping import MultiModelKeys
5452 SupportsPP ,
5553)
5654
55+ from vllm .model_executor .model_loader .weight_utils import (
56+ default_weight_loader ,
57+ )
58+ from vllm .model_executor .models .siglip2navit import Siglip2Encoder
59+ from vllm .attention .backends .registry import _Backend
60+ from vllm .model_executor .layers .quantization import QuantizationConfig
61+
62+ from vllm .model_executor .layers .linear import ReplicatedLinear
63+
5764# ===== TensorStream Compatibility Layer for Isaac MRoPE =====
5865# Minimal implementation of TensorStream classes needed for Isaac's 3D positional encoding
5966
@@ -316,9 +323,10 @@ def __init__(self, config: PixelShuffleSiglip2VisionConfig):
316323 self .embed_dim = config .hidden_size
317324 self .patch_size = config .patch_size
318325
319- self .patch_embedding = nn .Linear (
320- in_features = config .num_channels * self .patch_size * self .patch_size ,
321- out_features = self .embed_dim ,
326+ self .patch_embedding = ReplicatedLinear (
327+ input_size = config .num_channels * self .patch_size * self .patch_size ,
328+ output_size = self .embed_dim ,
329+ return_bias = False ,
322330 )
323331
324332 self .num_patches = config .num_patches
@@ -1058,37 +1066,10 @@ def get_replacement_isaac(item_idx: int):
10581066 )
10591067 ]
10601068
1061- from vllm .model_executor .model_loader .weight_utils import (
1062- default_weight_loader ,
1063- maybe_remap_kv_scale_name ,
1064- )
1065- from vllm .model_executor .models .utils import is_pp_missing_parameter
1066- from vllm .model_executor .models .siglip2navit import Siglip2VisionEmbeddings , Siglip2Encoder
1067- from vllm .attention .backends .registry import _Backend
1068- from vllm .model_executor .layers .quantization import QuantizationConfig
1069-
1070- class Siglip2VisionTransformer (nn .Module , SupportsMultiModal , SupportsLoRA , SupportsPP , SupportsMRoPE
1071- ):
1072-
1073- is_pooling_model = True
1074-
1075- merge_by_field_config = True
1076-
1077- packed_modules_mapping = {
1078- "qkv_proj" : [
1079- "q_proj" ,
1080- "k_proj" ,
1081- "v_proj" ,
1082- ],
1083- "gate_up_proj" : [
1084- "gate_proj" ,
1085- "up_proj" ,
1086- ],
1087- }
1088-
1069+ class Siglip2VisionTransformer (nn .Module ):
10891070 def __init__ (
10901071 self ,
1091- config ,
1072+ config : PixelShuffleSiglip2VisionConfig ,
10921073 quant_config : QuantizationConfig | None = None ,
10931074 prefix : str = "" ,
10941075 use_data_parallel : bool = False ,
@@ -1151,64 +1132,28 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
11511132 ("qkv_proj" , "q_proj" , "q" ),
11521133 ("qkv_proj" , "k_proj" , "k" ),
11531134 ("qkv_proj" , "v_proj" , "v" ),
1154- ("gate_up_proj" , "gate_proj" , 0 ),
1155- ("gate_up_proj" , "up_proj" , 1 ),
11561135 ]
1157- params_dict = dict (self .named_parameters (remove_duplicate = False ))
1136+ params_dict = dict (self .named_parameters ())
11581137 loaded_params : set [str ] = set ()
1138+
11591139 for name , loaded_weight in weights :
1160- if "rotary_emb.inv_freq" in name :
1161- continue
1162- if self .quant_config is not None and (
1163- scale_name := self .quant_config .get_cache_scale (name )
1164- ):
1165- # Loading kv cache quantization scales
1166- param = params_dict [scale_name ]
1167- weight_loader = getattr (param , "weight_loader" , default_weight_loader )
1168- loaded_weight = (
1169- loaded_weight if loaded_weight .dim () == 0 else loaded_weight [0 ]
1170- )
1171- weight_loader (param , loaded_weight )
1172- loaded_params .add (scale_name )
1173- continue
11741140 for param_name , weight_name , shard_id in stacked_params_mapping :
11751141 if weight_name not in name :
11761142 continue
11771143 name = name .replace (weight_name , param_name )
1178- # Skip loading extra bias for GPTQ models.
1179- if name .endswith (".bias" ) and name not in params_dict :
1180- continue
1181- if is_pp_missing_parameter (name , self ):
1182- continue
1183- if name .endswith ("scale" ):
1184- # Remapping the name of FP8 kv-scale.
1185- name = maybe_remap_kv_scale_name (name , params_dict )
1186- if name is None :
1187- continue
1144+
11881145 param = params_dict [name ]
1189- weight_loader = getattr (param , "weight_loader" , default_weight_loader )
1190- if weight_loader == default_weight_loader :
1191- weight_loader (param , loaded_weight )
1192- else :
1193- weight_loader (param , loaded_weight , shard_id )
1146+ weight_loader = param .weight_loader
1147+ weight_loader (param , loaded_weight , shard_id )
11941148 break
11951149 else :
1196- # Skip loading extra bias for GPTQ models.
1197- if name .endswith (".bias" ) and name not in params_dict :
1198- continue
1199- # Remapping the name of FP8 kv-scale.
1200- name = maybe_remap_kv_scale_name (name , params_dict )
1201- if name is None :
1202- continue
1203- if is_pp_missing_parameter (name , self ):
1204- continue
1205- print (f"qwen2: name={ name } " )
12061150 param = params_dict [name ]
12071151 weight_loader = getattr (param , "weight_loader" , default_weight_loader )
12081152 weight_loader (param , loaded_weight )
12091153 loaded_params .add (name )
12101154 return loaded_params
12111155
1156+
12121157@MULTIMODAL_REGISTRY .register_processor (
12131158 IsaacMultiModalProcessor ,
12141159 info = IsaacProcessingInfo ,
@@ -1217,6 +1162,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
12171162class IsaacForConditionalGeneration (
12181163 Qwen3ForCausalLM , SupportsMultiModal , SupportsLoRA , SupportsPP , SupportsMRoPE
12191164):
1165+
12201166 packed_modules_mapping = {
12211167 "qkv_proj" : [
12221168 "q_proj" ,
@@ -1230,7 +1176,7 @@ class IsaacForConditionalGeneration(
12301176 }
12311177
12321178 supports_encoder_tp_data = True
1233-
1179+
12341180 # To ensure correct weight loading and mapping.
12351181 hf_to_vllm_mapper = WeightsMapper (
12361182 orig_to_new_prefix = {
@@ -1261,14 +1207,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"):
12611207
12621208 # Initialize the parent class with updated config
12631209 super ().__init__ (vllm_config = vllm_config , prefix = prefix )
1264-
1210+
12651211 # Create the language model module to match checkpoint structure
12661212 self .language_model = nn .ModuleDict ({
12671213 "embed_tokens" : self .model .embed_tokens ,
12681214 "layers" : self .model .layers ,
12691215 "norm" : self .model .norm
12701216 })
1271-
1217+
12721218 config .vision_config .preserve_original_pe = True
12731219 config .vision_config .use_rope = False
12741220 config .vision_config .hidden_stride = config .vision_config .pixel_shuffle_scale_factor
@@ -1431,61 +1377,9 @@ def get_input_embeddings(
14311377
14321378 return inputs_embeds
14331379
1434- def merge_qkv_weights (
1435- weights : Iterable [tuple [str , torch .Tensor ]]
1436- ) -> Iterable [tuple [str , torch .Tensor ]]:
1437- """Merge separate Q, K, V projection weights into QKV format."""
1438-
1439- # Buffer to collect q, k, v weights for each layer
1440- qkv_buffer = {}
1441-
1442- for name , tensor in weights :
1443- # Check if this is a q/k/v projection weight
1444- if '.q_proj.' in name or '.k_proj.' in name or '.v_proj.' in name :
1445- # Extract the base name (everything before q/k/v_proj)
1446- if '.q_proj.' in name :
1447- base_name = name .replace ('.q_proj.' , '.qkv_proj.' )
1448- proj_type = 'q'
1449- elif '.k_proj.' in name :
1450- base_name = name .replace ('.k_proj.' , '.qkv_proj.' )
1451- proj_type = 'k'
1452- else : # v_proj
1453- base_name = name .replace ('.v_proj.' , '.qkv_proj.' )
1454- proj_type = 'v'
1455-
1456- # Store in buffer
1457- if base_name not in qkv_buffer :
1458- qkv_buffer [base_name ] = {}
1459- qkv_buffer [base_name ][proj_type ] = tensor
1460-
1461- # If we have all three (q, k, v), merge and yield
1462- if len (qkv_buffer [base_name ]) == 3 :
1463- q = qkv_buffer [base_name ]['q' ]
1464- k = qkv_buffer [base_name ]['k' ]
1465- v = qkv_buffer [base_name ]['v' ]
1466-
1467- # Concatenate along dim 0 for weight, dim agnostic for bias
1468- merged = torch .cat ([q , k , v ], dim = 0 )
1469- yield base_name , merged
1470-
1471- # Clear buffer
1472- del qkv_buffer [base_name ]
1473- else :
1474- # Pass through non-qkv weights unchanged
1475- yield name , tensor
1476-
1477- # Check if any incomplete qkv sets remain
1478- if qkv_buffer :
1479- raise ValueError (f"Incomplete QKV weights found: { list (qkv_buffer .keys ())} " )
1480-
1481-
14821380 def load_weights (self , weights : Iterable [tuple [str , torch .Tensor ]]) -> set [str ]:
14831381 skip_prefixes = []
1484- #if self.vision_embedding is None:
1485- # skip_prefixes.extend(["vision_embedding."])
1486-
1487- # Usage:
1488- #weights = self.merge_qkv_weights(weights)
1382+
14891383 loader = AutoWeightsLoader (self , skip_prefixes = skip_prefixes )
14901384 return loader .load_weights (weights , mapper = self .hf_to_vllm_mapper )
14911385
0 commit comments