diff --git a/xtuner/v1/loss/base_loss_ctx.py b/xtuner/v1/loss/base_loss_ctx.py
index 79680ddb5..1b42bebd0 100644
--- a/xtuner/v1/loss/base_loss_ctx.py
+++ b/xtuner/v1/loss/base_loss_ctx.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from abc import ABC, abstractmethod
-from typing import Annotated, Any, Literal, Self, TypeVar
+from typing import Annotated, Any, Literal, TypeVar
 
 import torch
 import torch.distributed as dist
@@ -9,6 +9,7 @@
 from pydantic import BaseModel, ConfigDict
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.nn.functional import all_reduce
+from typing_extensions import Self
 
 from xtuner.v1.loss.utils import sp_split
 
diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py
index 5aa34ed8e..cd3f02839 100644
--- a/xtuner/v1/train/rl_trainer.py
+++ b/xtuner/v1/train/rl_trainer.py
@@ -245,6 +245,7 @@ def __init__(
         self._total_epochs = total_epochs
         self._cur_step = 0
+        self._global_train_step = 1
 
         if skip_checkpoint_validation:
             patch_default_save_plan()
 
@@ -554,24 +555,19 @@ def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, ste
         self.logger.info(f"Prepared {len(data_batches)} training data batches")
         self._log_data_info(rollout_idx, data_info)
 
-        self._writer.add_scalar(
-            tag="time/onload",
-            scalar_value=step_timer_dict["onload"],
-            global_step=rollout_idx,
-        )
-
-        self._writer.add_scalar(
-            tag="time/prepare_data",
-            scalar_value=step_timer_dict["prepare_data"],
-            global_step=rollout_idx,
-        )
-
         with timer("training", step_timer_dict):
             workers_log_item: List[WorkerLogItem] = ray.get(
                 self._train_controller.fit.remote(
                     data_batches, pack_max_length=self._train_worker_cfg.pack_max_length, rollout_idx=rollout_idx
                 )
             )
+        self._log_train_metrics(rollout_idx, step_timer_dict, workers_log_item)
+
+    def _log_train_metrics(self, rollout_idx: int, step_timer_dict: dict, workers_log_item: List[WorkerLogItem]):
+        self._writer.add_scalar(tag="time/onload", scalar_value=step_timer_dict["onload"], global_step=rollout_idx)
+        self._writer.add_scalar(
+            tag="time/prepare_data", scalar_value=step_timer_dict["prepare_data"], global_step=rollout_idx
+        )
         self._writer.add_scalar(tag="time/training", scalar_value=step_timer_dict["training"], global_step=rollout_idx)
 
         rank0_log_item = workers_log_item[0]
@@ -591,6 +587,7 @@ def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, ste
         tb_entropy = {"entropy/train": rank0_log_item["train_entropy"]}
         self._writer.add_scalars(tag_scalar_dict=tb_entropy, global_step=rollout_idx)
 
+        train_start_step = self._global_train_step
         for worker_idx, log_item in enumerate(workers_log_item):
             mini_batch_metrics: dict[str, List[float]] = {}
             for mini_batch_log in log_item["train_metrics"]:
@@ -599,13 +596,21 @@ def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, ste
                 for k, v in rl_worker_log.items():
                     mini_batch_metrics.setdefault(k, []).append(cast(float, v))
 
+            for key, value in mini_batch_metrics.items():
+                avg_value = sum(value) / len(value)
+                self._writer.add_scalar(
+                    tag=f"train_metrics/worker_{worker_idx}/step_avg_{key}",
+                    scalar_value=avg_value,
+                    global_step=rollout_idx,
+                )
+
             for key, value in mini_batch_metrics.items():
                 for i, v in enumerate(value):
-                    global_step = (rollout_idx - 1) * len(value) + i + 1
+                    current_step = train_start_step + i
                     self._writer.add_scalar(
                         tag=f"train_metrics/worker_{worker_idx}/{key}",
                         scalar_value=v,
-                        global_step=global_step,
+                        global_step=current_step,
                     )
 
             rank_sft_log = log_item["sft_train_metrics"]
@@ -616,6 +621,9 @@ def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, ste
                     global_step=rollout_idx,
                 )
 
+        num_mini_batches = len(workers_log_item[0]["train_metrics"])
+        self._global_train_step += num_mini_batches
+
     def _sync_weights_and_save(self, rollout_idx: int, step_timer_dict: dict):
         """Synchronizes weights and saves checkpoints."""
         with timer("save_ckpt", step_timer_dict):
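For reference, a minimal standalone sketch of the step-counting scheme this patch introduces (the `StepCounterDemo` class and its method names are illustrative only and not part of xtuner): instead of deriving the TensorBoard step from `(rollout_idx - 1) * len(value) + i + 1`, which assumes every rollout yields the same number of mini-batches, a running `_global_train_step` counter is advanced by the actual mini-batch count after each rollout, so per-mini-batch metric steps stay monotonic.

```python
# Illustrative sketch only -- mirrors the counter logic added in rl_trainer.py,
# but StepCounterDemo is a hypothetical helper, not part of xtuner.

class StepCounterDemo:
    def __init__(self) -> None:
        # Mirrors `self._global_train_step = 1` in the trainer's __init__.
        self._global_train_step = 1

    def log_rollout(self, mini_batch_losses: list[float]) -> list[tuple[int, float]]:
        """Return (global_step, value) pairs as they would be written to TensorBoard."""
        train_start_step = self._global_train_step
        points = [(train_start_step + i, v) for i, v in enumerate(mini_batch_losses)]
        # Advance the counter by however many mini-batches this rollout produced,
        # so the next rollout continues where this one left off.
        self._global_train_step += len(mini_batch_losses)
        return points


demo = StepCounterDemo()
print(demo.log_rollout([0.9, 0.8, 0.7]))  # steps 1, 2, 3
print(demo.log_rollout([0.6, 0.5]))       # steps 4, 5 -- no collision despite fewer mini-batches
```

Under the old formula, a rollout that produced fewer mini-batches than its predecessor would reuse step values already written, producing overlapping points in TensorBoard; the running counter avoids that.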