From 0dfb9eb2227c9f949a1186e3f39bff9cdfbc2a7a Mon Sep 17 00:00:00 2001
From: YanhuiDua
Date: Fri, 30 Jan 2026 14:04:09 +0800
Subject: [PATCH 1/3] fix rl_trainer train_metrics global_step and add avg
 train_metrics
---
 xtuner/v1/loss/base_loss_ctx.py |  3 ++-
 xtuner/v1/train/rl_trainer.py   | 13 ++++++++++++-
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/xtuner/v1/loss/base_loss_ctx.py b/xtuner/v1/loss/base_loss_ctx.py
index 79680ddb5..1b42bebd0 100644
--- a/xtuner/v1/loss/base_loss_ctx.py
+++ b/xtuner/v1/loss/base_loss_ctx.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from abc import ABC, abstractmethod
-from typing import Annotated, Any, Literal, Self, TypeVar
+from typing import Annotated, Any, Literal, TypeVar
 
 import torch
 import torch.distributed as dist
@@ -9,6 +9,7 @@
 from pydantic import BaseModel, ConfigDict
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.nn.functional import all_reduce
+from typing_extensions import Self
 
 from xtuner.v1.loss.utils import sp_split
 
diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py
index 5aa34ed8e..979917242 100644
--- a/xtuner/v1/train/rl_trainer.py
+++ b/xtuner/v1/train/rl_trainer.py
@@ -245,6 +245,7 @@ def __init__(
         self._total_epochs = total_epochs
 
         self._cur_step = 0
+        self._train_mini_step = 1
 
         if skip_checkpoint_validation:
             patch_default_save_plan()
@@ -599,9 +600,17 @@ def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, ste
                 for k, v in rl_worker_log.items():
                     mini_batch_metrics.setdefault(k, []).append(cast(float, v))
 
+            for key, value in mini_batch_metrics.items():
+                avg_value = sum(value) / len(value)
+                self._writer.add_scalar(
+                    tag=f"train_metrics/worker_{worker_idx}/step_avg_{key}",
+                    scalar_value=avg_value,
+                    global_step=rollout_idx,
+                )
+
             for key, value in mini_batch_metrics.items():
                 for i, v in enumerate(value):
-                    global_step = (rollout_idx - 1) * len(value) + i + 1
+                    global_step = self._train_mini_step + i
                     self._writer.add_scalar(
                         tag=f"train_metrics/worker_{worker_idx}/{key}",
                         scalar_value=v,
@@ -616,6 +625,8 @@ def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, ste
                     global_step=rollout_idx,
                 )
 
+        self._train_mini_step += len(workers_log_item[0]["train_metrics"])
+
     def _sync_weights_and_save(self, rollout_idx: int, step_timer_dict: dict):
         """Synchronizes weights and saves checkpoints."""
         with timer("save_ckpt", step_timer_dict):

From 52113198098e376e77af4692f9474d145540fb2a Mon Sep 17 00:00:00 2001
From: YanhuiDua
Date: Fri, 30 Jan 2026 14:46:32 +0800
Subject: [PATCH 2/3] Handle the case where the length of metrics is zero.
---
 xtuner/v1/train/rl_trainer.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py
index 979917242..75209ee6c 100644
--- a/xtuner/v1/train/rl_trainer.py
+++ b/xtuner/v1/train/rl_trainer.py
@@ -601,6 +601,8 @@ def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, ste
                     mini_batch_metrics.setdefault(k, []).append(cast(float, v))
 
             for key, value in mini_batch_metrics.items():
+                if len(value) == 0:
+                    continue
                 avg_value = sum(value) / len(value)
                 self._writer.add_scalar(
                     tag=f"train_metrics/worker_{worker_idx}/step_avg_{key}",

From c3edb66aae7ed895dc36919b17d020a4e809a517 Mon Sep 17 00:00:00 2001
From: YanhuiDua
Date: Mon, 2 Feb 2026 15:35:12 +0800
Subject: [PATCH 3/3] fix comments
---
 xtuner/v1/train/rl_trainer.py | 31 +++++++++++++------------------
 1 file changed, 13 insertions(+), 18 deletions(-)

diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py
index 75209ee6c..cd3f02839 100644
--- a/xtuner/v1/train/rl_trainer.py
+++ b/xtuner/v1/train/rl_trainer.py
@@ -245,7 +245,7 @@ def __init__(
         self._total_epochs = total_epochs
 
         self._cur_step = 0
-        self._train_mini_step = 1
+        self._global_train_step = 1
 
         if skip_checkpoint_validation:
             patch_default_save_plan()
@@ -555,24 +555,19 @@ def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, ste
         self.logger.info(f"Prepared {len(data_batches)} training data batches")
         self._log_data_info(rollout_idx, data_info)
 
-        self._writer.add_scalar(
-            tag="time/onload",
-            scalar_value=step_timer_dict["onload"],
-            global_step=rollout_idx,
-        )
-
-        self._writer.add_scalar(
-            tag="time/prepare_data",
-            scalar_value=step_timer_dict["prepare_data"],
-            global_step=rollout_idx,
-        )
-
         with timer("training", step_timer_dict):
             workers_log_item: List[WorkerLogItem] = ray.get(
                 self._train_controller.fit.remote(
                     data_batches, pack_max_length=self._train_worker_cfg.pack_max_length, rollout_idx=rollout_idx
                 )
             )
+        self._log_train_metrics(rollout_idx, step_timer_dict, workers_log_item)
+
+    def _log_train_metrics(self, rollout_idx: int, step_timer_dict: dict, workers_log_item: List[WorkerLogItem]):
+        self._writer.add_scalar(tag="time/onload", scalar_value=step_timer_dict["onload"], global_step=rollout_idx)
+        self._writer.add_scalar(
+            tag="time/prepare_data", scalar_value=step_timer_dict["prepare_data"], global_step=rollout_idx
+        )
         self._writer.add_scalar(tag="time/training", scalar_value=step_timer_dict["training"], global_step=rollout_idx)
 
         rank0_log_item = workers_log_item[0]
@@ -592,6 +587,7 @@ def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, ste
             tb_entropy = {"entropy/train": rank0_log_item["train_entropy"]}
             self._writer.add_scalars(tag_scalar_dict=tb_entropy, global_step=rollout_idx)
 
+        train_start_step = self._global_train_step
         for worker_idx, log_item in enumerate(workers_log_item):
             mini_batch_metrics: dict[str, List[float]] = {}
             for mini_batch_log in log_item["train_metrics"]:
@@ -601,8 +597,6 @@ def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, ste
                     mini_batch_metrics.setdefault(k, []).append(cast(float, v))
 
             for key, value in mini_batch_metrics.items():
-                if len(value) == 0:
-                    continue
                 avg_value = sum(value) / len(value)
                 self._writer.add_scalar(
                     tag=f"train_metrics/worker_{worker_idx}/step_avg_{key}",
@@ -612,11 +606,11 @@ def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, ste
 
             for key, value in mini_batch_metrics.items():
                 for i, v in enumerate(value):
-                    global_step = self._train_mini_step + i
+                    current_step = train_start_step + i
                     self._writer.add_scalar(
                         tag=f"train_metrics/worker_{worker_idx}/{key}",
                         scalar_value=v,
-                        global_step=global_step,
+                        global_step=current_step,
                     )
 
             rank_sft_log = log_item["sft_train_metrics"]
@@ -627,7 +621,8 @@ def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, ste
                     global_step=rollout_idx,
                 )
 
-        self._train_mini_step += len(workers_log_item[0]["train_metrics"])
+        num_mini_batches = len(workers_log_item[0]["train_metrics"])
+        self._global_train_step += num_mini_batches
 
     def _sync_weights_and_save(self, rollout_idx: int, step_timer_dict: dict):
         """Synchronizes weights and saves checkpoints."""
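
Background on the accounting this series moves to: per-mini-batch train metrics used to be
logged at global_step = (rollout_idx - 1) * len(value) + i + 1, which assumes every rollout
produces the same number of mini-batches; the patches instead keep a counter that persists
across rollouts (_train_mini_step, later renamed _global_train_step) and additionally log a
per-rollout average of each metric under step_avg_<key>. The sketch below illustrates only
that accounting in isolation. MiniStepLogger and ConsoleWriter are hypothetical stand-ins,
not xtuner classes; the stand-in writer just prints, mirroring the
add_scalar(tag=..., scalar_value=..., global_step=...) call shape used in the patch.

    from typing import Dict, List


    class ConsoleWriter:
        """Stand-in for a TensorBoard-style writer; just prints each scalar."""

        def add_scalar(self, tag: str, scalar_value: float, global_step: int) -> None:
            print(f"{tag} step={global_step} value={scalar_value:.4f}")


    class MiniStepLogger:
        """Hypothetical helper mirroring the step accounting from the patches."""

        def __init__(self, writer: ConsoleWriter) -> None:
            self._writer = writer
            self._global_train_step = 1  # persists across rollouts

        def log_rollout(self, rollout_idx: int, mini_batch_metrics: Dict[str, List[float]]) -> None:
            # One averaged point per rollout, indexed by rollout_idx.
            for key, values in mini_batch_metrics.items():
                if not values:
                    continue
                self._writer.add_scalar(
                    tag=f"train_metrics/step_avg_{key}",
                    scalar_value=sum(values) / len(values),
                    global_step=rollout_idx,
                )

            # One point per mini-batch, indexed by the persistent counter instead of
            # (rollout_idx - 1) * len(values) + i + 1, so the x-axis stays monotonic
            # even when rollouts produce different numbers of mini-batches.
            num_mini_batches = 0
            for key, values in mini_batch_metrics.items():
                num_mini_batches = max(num_mini_batches, len(values))
                for i, v in enumerate(values):
                    self._writer.add_scalar(
                        tag=f"train_metrics/{key}",
                        scalar_value=v,
                        global_step=self._global_train_step + i,
                    )

            # Advance the counter once per rollout, after all keys are written.
            # (The patches advance it by the number of mini-batch log entries from
            # worker 0; max-over-keys plays the same role in this simplified sketch.)
            self._global_train_step += num_mini_batches


    logger = MiniStepLogger(ConsoleWriter())
    logger.log_rollout(rollout_idx=1, mini_batch_metrics={"loss": [1.2, 1.1, 1.0]})
    logger.log_rollout(rollout_idx=2, mini_batch_metrics={"loss": [0.9, 0.8]})

With the assumed inputs above, the per-mini-batch points land at steps 1-3 and then 4-5
across the two rollouts even though the rollouts differ in length, while the averaged
values land at steps 1 and 2 respectively.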