9 changes: 5 additions & 4 deletions xtuner/v1/engine/train_engine.py

@@ -411,12 +411,13 @@ def step_optimizer(self, grad_norm):
         self.optimizer.zero_grad()
         return grad_norm

+    # TODO: Should be removed
     @staticmethod
     def clean_param_name(name: str) -> str:
-        if "._checkpoint_wrapped_module." in name:
-            name = name.replace("._checkpoint_wrapped_module.", ".")
-        if "._orig_mod." in name:
-            name = name.replace("._orig_mod.", ".")
+        if "_checkpoint_wrapped_module." in name:
+            name = name.replace("_checkpoint_wrapped_module.", "")
+        if "_orig_mod." in name:
+            name = name.replace("_orig_mod.", "")
         return name

     def save_hf(self, hf_dir: str, save_dtype: torch.dtype = torch.bfloat16):
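Side note (outside the patch): the updated clean_param_name drops the leading dot from the patterns it strips. For a wrapper prefix in the middle of a name the result is unchanged, but a prefix at the very start of a name (no preceding dot) is now handled as well. A minimal sketch of the difference, assuming parameter names of the shape typically produced by activation-checkpoint and torch.compile wrappers:

# Illustrative names only; the wrapped modules are assumed, not taken from this patch.
wrapped_names = [
    "layers.0._checkpoint_wrapped_module.attn.q_proj.weight",  # wrapper in the middle
    "_orig_mod.layers.0.mlp.up_proj.weight",                   # wrapper at the start
]

def old_clean(name: str) -> str:
    # Previous logic: only strips the wrapper when a dot precedes it.
    return name.replace("._checkpoint_wrapped_module.", ".").replace("._orig_mod.", ".")

def new_clean(name: str) -> str:
    # New logic: strips the wrapper prefix wherever it appears.
    return name.replace("_checkpoint_wrapped_module.", "").replace("_orig_mod.", "")

for name in wrapped_names:
    print(old_clean(name), "->", new_clean(name))
# layers.0.attn.q_proj.weight -> layers.0.attn.q_proj.weight
# _orig_mod.layers.0.mlp.up_proj.weight -> layers.0.mlp.up_proj.weight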
20 changes: 14 additions & 6 deletions xtuner/v1/model/base.py

@@ -85,6 +85,9 @@ class XTunerBaseModelConfig(PydanticBaseModel):
     def hf_config(self) -> PretrainedConfig | None:
         raise NotImplementedError

+    def build(self):
+        raise NotImplementedError
+

 DEFAULT_FLOAT8_CFG = {
     "xtuner.v1.float8.fsdp_utils.tensor_to_per_block_fp8_scales": TorchCompileOption(fullgraph=True),

@@ -250,7 +253,11 @@ def __init__(self, config: XTunerBaseModelConfig):
     def set_hf(self, hf_path: str | Path):
         self._hf_path = Path(hf_path)

-    def from_hf(self, hf_path: str | Path, strict: bool = True) -> tuple:
+    def from_hf(
+        self, hf_path: str | Path, strict: bool = True
+    ) -> tuple[
+        Annotated[set[str], "loaded keys"], Annotated[set[str], "unloaded keys"], Annotated[set[str], "missing keys"]
+    ]:
         self._hf_path = Path(hf_path)

         if isinstance(hf_path, Path):

@@ -348,7 +355,7 @@ def init_weights(self):
         from xtuner.v1.utils import default_init_weights

         initialized_params = default_init_weights(self)
-        if missing := {name for name, _ in self.named_parameters()} - initialized_params:
+        if missing := {self._clean_param_name(name) for name, _ in self.named_parameters()} - initialized_params:
            raise RuntimeError(f"{missing} is not initialized")

     def _init_load_spec(self) -> None:

@@ -797,11 +804,12 @@ def _get_same_hf_param(
         if buffer_tensor_list:
             yield buffer_name_list, buffer_tensor_list

+    # TODO: Use `xtuner.v1.utils.misc.clean_param_name`
     def _clean_param_name(self, name: str) -> str:
-        if "._checkpoint_wrapped_module." in name:
-            name = name.replace("._checkpoint_wrapped_module.", ".")
-        if "._orig_mod." in name:
-            name = name.replace("._orig_mod.", ".")
+        if "_checkpoint_wrapped_module." in name:
+            name = name.replace("_checkpoint_wrapped_module.", "")
+        if "_orig_mod." in name:
+            name = name.replace("_orig_mod.", "")
         return name

     def _group_param_by_load_spec(self, load_enum: LoadEnum):
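For orientation, the widened return annotation lets callers unpack the three key sets directly. A minimal usage sketch, assuming a built model instance and a local checkpoint directory (both names here are hypothetical):

# `model` stands for any XTuner model exposing the `from_hf` method shown above.
loaded, unloaded, missing = model.from_hf("/path/to/hf_checkpoint", strict=False)

# Presumably: "loaded keys" were matched and copied, "unloaded keys" were present in the
# checkpoint but not consumed, and "missing keys" are model parameters without a checkpoint entry.
if missing:
    print(f"Parameters left at their initial values: {sorted(missing)}")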
1 change: 1 addition & 0 deletions xtuner/v1/model/compose/qwen3_vl/modeling_vision.py

@@ -276,6 +276,7 @@ def init_weights(self):

         for layer_idx, layer in enumerate(self.blocks):
             for name, module in layer.named_modules():
+                name = self._clean_param_name(name)
                 if isinstance(module, nn.Linear):
                     init_params(module.weight,
                                 partial(torch.nn.init.normal_, mean=0.0, std=self.config.initializer_range))
2 changes: 1 addition & 1 deletion xtuner/v1/train/trainer.py

@@ -33,7 +33,7 @@
 from xtuner.v1.engine import LossLog, OtherLog, TrainEngine
 from xtuner.v1.engine.vision_compose_train_engine import VisionComposeTrainEngine
 from xtuner.v1.loss import CELossConfig, CELossContext
-from xtuner.v1.model.base import ModelItem, TransformerConfig, XTunerBaseModelConfig
+from xtuner.v1.model.base import ModelItem, XTunerBaseModelConfig
 from xtuner.v1.model.compose.base import BaseComposeConfig
 from xtuner.v1.model.moe.moe import MoEConfig
 from xtuner.v1.patch import patch_default_save_plan
2 changes: 2 additions & 0 deletions xtuner/v1/utils/__init__.py

@@ -11,6 +11,7 @@
     XTUNER_DETERMINISTIC,
     FunctionEnum,
     SharedMemory,
+    clean_param_name,
     get_function_type,
     get_padding_length,
     is_hf_model_path,

@@ -57,4 +58,5 @@
     "monkey_unpatch_torch_reductions",
     "ray_method",
     "profile_time",
+    "clean_param_name",
 ]
7 changes: 4 additions & 3 deletions xtuner/v1/utils/init_weight.py

@@ -5,7 +5,8 @@
 import torch.nn as nn
 from torch.distributed.tensor import DTensor, distribute_tensor

-from xtuner.v1.utils import get_device
+from .device import get_device
+from .misc import clean_param_name


 DEVICE = get_device()

@@ -51,15 +52,15 @@ def _default_init_atom(name: str, module: nn.Module):
         if hasattr(module, "bias") and module.bias is not None:
             bias = cast(torch.Tensor, module.bias)
             init_params(bias, nn.init.zeros_)
-            initialized_params.add(f"{name}.bias")
+            initialized_params.add(clean_param_name(f"{name}.bias"))

         if hasattr(module, "weight") and module.weight is not None:
             weight = cast(torch.Tensor, module.weight)
             if "norm" in name:
                 init_params(weight, nn.init.ones_)
             else:
                 init_params(weight, partial(nn.init.normal_, mean=0.0, std=0.02))
-            initialized_params.add(f"{name}.weight")
+            initialized_params.add(clean_param_name(f"{name}.weight"))

     _init_weights_recursive("", module)
     return initialized_params
8 changes: 8 additions & 0 deletions xtuner/v1/utils/misc.py

@@ -192,3 +192,11 @@ def get_function_full_qualname(function: FunctionType) -> str:

     full_qualname = f"{module_name}.{qualname}"
     return full_qualname
+
+
+def clean_param_name(name: str) -> str:
+    if "_checkpoint_wrapped_module." in name:
+        name = name.replace("_checkpoint_wrapped_module.", "")
+    if "_orig_mod." in name:
+        name = name.replace("_orig_mod.", "")
+    return name
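A toy illustration of why both sides of the init-weights bookkeeping are cleaned: default_init_weights records cleaned names, and init_weights now also cleans the names coming from named_parameters before comparing, so wrapper prefixes no longer produce spurious "not initialized" errors. The sketch below uses plain strings and assumes a wrapped-module name of the usual shape:

def clean_param_name(name: str) -> str:
    if "_checkpoint_wrapped_module." in name:
        name = name.replace("_checkpoint_wrapped_module.", "")
    if "_orig_mod." in name:
        name = name.replace("_orig_mod.", "")
    return name

# Name as recorded by the init helper (already cleaned) ...
initialized_params = {clean_param_name("layers.0._checkpoint_wrapped_module.attn.weight")}

# ... versus the raw name reported by named_parameters() on the wrapped model.
raw_names = {"layers.0._checkpoint_wrapped_module.attn.weight"}

missing = {clean_param_name(n) for n in raw_names} - initialized_params
assert not missing  # both sides agree once the wrapper prefix is stripped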