xtuner/v1/loss/base_loss_ctx.py: 3 changes (2 additions & 1 deletion)
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from abc import ABC, abstractmethod
-from typing import Annotated, Any, Literal, Self, TypeVar
+from typing import Annotated, Any, Literal, TypeVar
 
 import torch
 import torch.distributed as dist
@@ -9,6 +9,7 @@
 from pydantic import BaseModel, ConfigDict
 from torch.distributed.device_mesh import DeviceMesh
 from torch.distributed.nn.functional import all_reduce
+from typing_extensions import Self
 
 from xtuner.v1.loss.utils import sp_split

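A note on the import change above: `typing.Self` only exists in the standard library from Python 3.11 onward, while `typing_extensions.Self` also works on older interpreters (and simply re-exports the stdlib version on 3.11+), which is presumably why the PR switches the import source. A minimal sketch of the pattern, using a hypothetical `LossContext` class that is not taken from the file:

```python
from typing_extensions import Self  # typing.Self exists only on Python >= 3.11


class LossContext:  # hypothetical class, not taken from base_loss_ctx.py
    def copy(self) -> Self:
        # ``Self`` lets a subclass keep its own concrete type in the signature,
        # which a plain "LossContext" annotation would not.
        return type(self)()
```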
xtuner/v1/train/rl_trainer.py: 36 changes (22 additions & 14 deletions)
@@ -245,6 +245,7 @@ def __init__(
 
         self._total_epochs = total_epochs
         self._cur_step = 0
+        self._global_train_step = 1
 
         if skip_checkpoint_validation:
             patch_default_save_plan()
@@ -554,24 +555,19 @@ def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, ste
         self.logger.info(f"Prepared {len(data_batches)} training data batches")
         self._log_data_info(rollout_idx, data_info)
 
-        self._writer.add_scalar(
-            tag="time/onload",
-            scalar_value=step_timer_dict["onload"],
-            global_step=rollout_idx,
-        )
-
-        self._writer.add_scalar(
-            tag="time/prepare_data",
-            scalar_value=step_timer_dict["prepare_data"],
-            global_step=rollout_idx,
-        )
-
         with timer("training", step_timer_dict):
             workers_log_item: List[WorkerLogItem] = ray.get(
                 self._train_controller.fit.remote(
                     data_batches, pack_max_length=self._train_worker_cfg.pack_max_length, rollout_idx=rollout_idx
                 )
             )
+        self._log_train_metrics(rollout_idx, step_timer_dict, workers_log_item)
 
+    def _log_train_metrics(self, rollout_idx: int, step_timer_dict: dict, workers_log_item: List[WorkerLogItem]):
+        self._writer.add_scalar(tag="time/onload", scalar_value=step_timer_dict["onload"], global_step=rollout_idx)
+        self._writer.add_scalar(
+            tag="time/prepare_data", scalar_value=step_timer_dict["prepare_data"], global_step=rollout_idx
+        )
+        self._writer.add_scalar(tag="time/training", scalar_value=step_timer_dict["training"], global_step=rollout_idx)
 
         rank0_log_item = workers_log_item[0]
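A side effect of moving the writer calls into `_log_train_metrics` is that the call appears to sit after the `with timer("training", ...)` block has exited, so the elapsed training time is already recorded and a new `time/training` scalar can be logged alongside the other timings. A rough sketch of that timer-then-log ordering, assuming a context manager that stores elapsed seconds under its key (the `timer` below is an illustrative stand-in, not xtuner's actual implementation):

```python
import time
from contextlib import contextmanager


@contextmanager
def timer(key: str, store: dict):
    # Illustrative stand-in: record wall-clock seconds spent in the block under store[key].
    start = time.perf_counter()
    try:
        yield
    finally:
        store[key] = time.perf_counter() - start


step_timer_dict: dict[str, float] = {}
with timer("training", step_timer_dict):
    time.sleep(0.01)  # stand-in for the actual training call

# Only after the with-block exits is step_timer_dict["training"] populated,
# which is why the metric logging happens outside the block.
print(f"time/training = {step_timer_dict['training']:.3f}s")
```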
@@ -591,6 +587,7 @@ def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, ste
         tb_entropy = {"entropy/train": rank0_log_item["train_entropy"]}
         self._writer.add_scalars(tag_scalar_dict=tb_entropy, global_step=rollout_idx)
 
+        train_start_step = self._global_train_step
         for worker_idx, log_item in enumerate(workers_log_item):
             mini_batch_metrics: dict[str, List[float]] = {}
             for mini_batch_log in log_item["train_metrics"]:
@@ -599,13 +596,21 @@ def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, ste
                 for k, v in rl_worker_log.items():
                     mini_batch_metrics.setdefault(k, []).append(cast(float, v))
 
+            for key, value in mini_batch_metrics.items():
+                avg_value = sum(value) / len(value)

Collaborator
Could this value ever end up as an empty list, causing a division-by-zero error here?

Collaborator (Author)
thanks, done

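On the empty-list question above: because the metrics dict is filled via `setdefault(k, []).append(...)`, a key only exists after at least one value has been appended, so `len(value)` should never be zero here; still, a defensive guard is cheap. A minimal sketch of such a guard (illustrative only, not necessarily the exact fix that was pushed):

```python
mini_batch_metrics: dict[str, list[float]] = {"loss": [0.42, 0.37], "grad_norm": []}

for key, value in mini_batch_metrics.items():
    if not value:
        continue  # an empty list would make the division below raise ZeroDivisionError
    avg_value = sum(value) / len(value)
    print(key, avg_value)
```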
+                self._writer.add_scalar(
+                    tag=f"train_metrics/worker_{worker_idx}/step_avg_{key}",
+                    scalar_value=avg_value,
+                    global_step=rollout_idx,
+                )
+
             for key, value in mini_batch_metrics.items():
                 for i, v in enumerate(value):
-                    global_step = (rollout_idx - 1) * len(value) + i + 1
+                    current_step = train_start_step + i
                     self._writer.add_scalar(
                         tag=f"train_metrics/worker_{worker_idx}/{key}",
                         scalar_value=v,
-                        global_step=global_step,
+                        global_step=current_step,
                     )
 
             rank_sft_log = log_item["sft_train_metrics"]
@@ -616,6 +621,9 @@ def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, ste
                     global_step=rollout_idx,
                 )
 
+        num_mini_batches = len(workers_log_item[0]["train_metrics"])
+        self._global_train_step += num_mini_batches
+
     def _sync_weights_and_save(self, rollout_idx: int, step_timer_dict: dict):
         """Synchronizes weights and saves checkpoints."""
        with timer("save_ckpt", step_timer_dict):
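A sketch of what the global-step change above buys. The old `(rollout_idx - 1) * len(value) + i + 1` formula implicitly assumes every rollout produces the same number of mini-batches, whereas the new `_global_train_step` counter starts at 1 and advances by however many mini-batches were actually trained, so per-mini-batch steps stay strictly increasing even when the count varies. The per-rollout counts below are invented for illustration:

```python
# Hypothetical mini-batch counts per rollout (they need not be equal).
mini_batches_per_rollout = [4, 4, 3]

# Old scheme: derived from rollout_idx, assumes a fixed count per rollout.
old_steps = []
for rollout_idx, n in enumerate(mini_batches_per_rollout, start=1):
    old_steps += [(rollout_idx - 1) * n + i + 1 for i in range(n)]

# New scheme: a persistent counter that mirrors the _global_train_step bookkeeping,
# starting at 1 and advancing by the number of mini-batches actually trained.
new_steps = []
global_train_step = 1
for n in mini_batches_per_rollout:
    new_steps += [global_train_step + i for i in range(n)]
    global_train_step += n

print(old_steps)  # [1, 2, 3, 4, 5, 6, 7, 8, 7, 8, 9]  -> steps collide once the count changes
print(new_steps)  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] -> strictly increasing
```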