diff --git a/xtuner/v1/engine/train_engine.py b/xtuner/v1/engine/train_engine.py
index 459549650..c6efcc6ea 100644
--- a/xtuner/v1/engine/train_engine.py
+++ b/xtuner/v1/engine/train_engine.py
@@ -411,12 +411,13 @@ def step_optimizer(self, grad_norm):
         self.optimizer.zero_grad()
         return grad_norm
 
+    # TODO: Should be removed
     @staticmethod
     def clean_param_name(name: str) -> str:
-        if "._checkpoint_wrapped_module." in name:
-            name = name.replace("._checkpoint_wrapped_module.", ".")
-        if "._orig_mod." in name:
-            name = name.replace("._orig_mod.", ".")
+        if "_checkpoint_wrapped_module." in name:
+            name = name.replace("_checkpoint_wrapped_module.", "")
+        if "_orig_mod." in name:
+            name = name.replace("_orig_mod.", "")
         return name
 
     def save_hf(self, hf_dir: str, save_dtype: torch.dtype = torch.bfloat16):
diff --git a/xtuner/v1/model/base.py b/xtuner/v1/model/base.py
index 5148f29d9..b3b74eb8e 100644
--- a/xtuner/v1/model/base.py
+++ b/xtuner/v1/model/base.py
@@ -85,6 +85,9 @@ class XTunerBaseModelConfig(PydanticBaseModel):
     def hf_config(self) -> PretrainedConfig | None:
         raise NotImplementedError
 
+    def build(self):
+        raise NotImplementedError
+
 
 DEFAULT_FLOAT8_CFG = {
     "xtuner.v1.float8.fsdp_utils.tensor_to_per_block_fp8_scales": TorchCompileOption(fullgraph=True),
@@ -250,7 +253,11 @@ def __init__(self, config: XTunerBaseModelConfig):
     def set_hf(self, hf_path: str | Path):
         self._hf_path = Path(hf_path)
 
-    def from_hf(self, hf_path: str | Path, strict: bool = True) -> tuple:
+    def from_hf(
+        self, hf_path: str | Path, strict: bool = True
+    ) -> tuple[
+        Annotated[set[str], "loaded keys"], Annotated[set[str], "unloaded keys"], Annotated[set[str], "missing keys"]
+    ]:
         self._hf_path = Path(hf_path)
 
         if isinstance(hf_path, Path):
@@ -348,7 +355,7 @@ def init_weights(self):
         from xtuner.v1.utils import default_init_weights
 
         initialized_params = default_init_weights(self)
-        if missing := {name for name, _ in self.named_parameters()} - initialized_params:
+        if missing := {self._clean_param_name(name) for name, _ in self.named_parameters()} - initialized_params:
             raise RuntimeError(f"{missing} is not initialized")
 
     def _init_load_spec(self) -> None:
@@ -797,11 +804,12 @@ def _get_same_hf_param(
         if buffer_tensor_list:
             yield buffer_name_list, buffer_tensor_list
 
+    # TODO: Use `xtuner.v1.utils.misc.clean_param_name`
     def _clean_param_name(self, name: str) -> str:
-        if "._checkpoint_wrapped_module." in name:
-            name = name.replace("._checkpoint_wrapped_module.", ".")
-        if "._orig_mod." in name:
-            name = name.replace("._orig_mod.", ".")
+        if "_checkpoint_wrapped_module." in name:
+            name = name.replace("_checkpoint_wrapped_module.", "")
+        if "_orig_mod." in name:
+            name = name.replace("_orig_mod.", "")
         return name
 
     def _group_param_by_load_spec(self, load_enum: LoadEnum):
diff --git a/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py b/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py
index 38d417f84..57eb517f9 100644
--- a/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py
+++ b/xtuner/v1/model/compose/qwen3_vl/modeling_vision.py
@@ -276,6 +276,7 @@ def init_weights(self):
 
         for layer_idx, layer in enumerate(self.blocks):
             for name, module in layer.named_modules():
+                name = self._clean_param_name(name)
                 if isinstance(module, nn.Linear):
                     init_params(
                         module.weight, partial(torch.nn.init.normal_, mean=0.0, std=self.config.initializer_range)
diff --git a/xtuner/v1/train/trainer.py b/xtuner/v1/train/trainer.py
index 8271c102b..f2c095cc3 100644
--- a/xtuner/v1/train/trainer.py
+++ b/xtuner/v1/train/trainer.py
@@ -33,7 +33,7 @@
 from xtuner.v1.engine import LossLog, OtherLog, TrainEngine
 from xtuner.v1.engine.vision_compose_train_engine import VisionComposeTrainEngine
 from xtuner.v1.loss import CELossConfig, CELossContext
-from xtuner.v1.model.base import ModelItem, TransformerConfig, XTunerBaseModelConfig
+from xtuner.v1.model.base import ModelItem, XTunerBaseModelConfig
 from xtuner.v1.model.compose.base import BaseComposeConfig
 from xtuner.v1.model.moe.moe import MoEConfig
 from xtuner.v1.patch import patch_default_save_plan
diff --git a/xtuner/v1/utils/__init__.py b/xtuner/v1/utils/__init__.py
index a3241e773..691893783 100644
--- a/xtuner/v1/utils/__init__.py
+++ b/xtuner/v1/utils/__init__.py
@@ -11,6 +11,7 @@
     XTUNER_DETERMINISTIC,
     FunctionEnum,
     SharedMemory,
+    clean_param_name,
     get_function_type,
     get_padding_length,
     is_hf_model_path,
@@ -57,4 +58,5 @@
     "monkey_unpatch_torch_reductions",
     "ray_method",
     "profile_time",
+    "clean_param_name",
 ]
diff --git a/xtuner/v1/utils/init_weight.py b/xtuner/v1/utils/init_weight.py
index 0bb5153f5..ed0fce487 100644
--- a/xtuner/v1/utils/init_weight.py
+++ b/xtuner/v1/utils/init_weight.py
@@ -5,7 +5,8 @@
 import torch.nn as nn
 from torch.distributed.tensor import DTensor, distribute_tensor
 
-from xtuner.v1.utils import get_device
+from .device import get_device
+from .misc import clean_param_name
 
 
 DEVICE = get_device()
@@ -51,7 +52,7 @@ def _default_init_atom(name: str, module: nn.Module):
         if hasattr(module, "bias") and module.bias is not None:
             bias = cast(torch.Tensor, module.bias)
             init_params(bias, nn.init.zeros_)
-            initialized_params.add(f"{name}.bias")
+            initialized_params.add(clean_param_name(f"{name}.bias"))
 
         if hasattr(module, "weight") and module.weight is not None:
             weight = cast(torch.Tensor, module.weight)
@@ -59,7 +60,7 @@
                 init_params(weight, nn.init.ones_)
             else:
                 init_params(weight, partial(nn.init.normal_, mean=0.0, std=0.02))
-            initialized_params.add(f"{name}.weight")
+            initialized_params.add(clean_param_name(f"{name}.weight"))
 
     _init_weights_recursive("", module)
     return initialized_params
diff --git a/xtuner/v1/utils/misc.py b/xtuner/v1/utils/misc.py
index 20fae0cc8..e9aaf82bb 100644
--- a/xtuner/v1/utils/misc.py
+++ b/xtuner/v1/utils/misc.py
@@ -192,3 +192,11 @@ def get_function_full_qualname(function: FunctionType) -> str:
     full_qualname = f"{module_name}.{qualname}"
 
     return full_qualname
+
+
+def clean_param_name(name: str) -> str:
+    if "_checkpoint_wrapped_module." in name:
+        name = name.replace("_checkpoint_wrapped_module.", "")
+    if "_orig_mod." in name:
+        name = name.replace("_orig_mod.", "")
+    return name
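
For reference, a quick usage sketch of the shared helper introduced in `xtuner/v1/utils/misc.py` (and re-exported from `xtuner.v1.utils` in the `__init__.py` hunk above). It strips the `_checkpoint_wrapped_module.` and `_orig_mod.` segments that activation checkpointing and `torch.compile` insert into dotted parameter names, so cleaned names line up with the original module tree; the example names below are illustrative, not taken from the codebase:

    from xtuner.v1.utils import clean_param_name

    # Wrapper segments inserted by activation checkpointing / torch.compile are dropped.
    assert clean_param_name("layers.0._checkpoint_wrapped_module.mlp.weight") == "layers.0.mlp.weight"
    assert clean_param_name("model._orig_mod.embed_tokens.weight") == "model.embed_tokens.weight"
    # Names without wrapper segments pass through unchanged.
    assert clean_param_name("lm_head.weight") == "lm_head.weight"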