Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions src/maxtext/trainers/post_train/rl/utils_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,15 +535,25 @@ def check_correctness(extracted_response: str, acceptable_answers: list[str], tm
def get_optimizer(tmvp_config: Any, max_train_steps: int) -> optax.GradientTransformation:
"""Function to obtain an optax optimizer, currently we use adamw.

Schedule shape is controlled by `learning_rate_schedule_steps` when set
(>0); this decouples warmup/decay shape from training length so the same
schedule can be applied across runs of different num_batches. Default
(-1) falls back to `max_train_steps` for backward compatibility — matches
the documented behavior of base.yml's `learning_rate_schedule_steps: -1`
("By default the length of the schedule is set to the number of steps").
The LR schedule length defaults to the actual RL run length
(`max_train_steps` = num_batches * num_iterations * train_fraction *
num_epoch). RL does not use the top-level `steps` (a pretraining concept) for
its run length, so the schedule must track `max_train_steps`, not `steps`.

`learning_rate_schedule_steps` may be set to decouple the schedule shape from
the run length (e.g. to match a fixed schedule across runs with different
num_batches). It is honored only as a deliberate override: the config
validator (`MaxTextConfig.set_derived_and_validate_values`) rewrites
`learning_rate_schedule_steps == -1` to `steps` before this runs, so a value
equal to `steps` means "unset" and falls back to `max_train_steps`; only a
value that differs from `steps` is treated as an explicit schedule length.
"""
schedule_steps = getattr(tmvp_config, "learning_rate_schedule_steps", -1)
if schedule_steps is None or schedule_steps <= 0:
lr_schedule_steps = getattr(tmvp_config, "learning_rate_schedule_steps", -1)
config_steps = getattr(tmvp_config, "steps", -1)
if lr_schedule_steps is not None and lr_schedule_steps > 0 and lr_schedule_steps != config_steps:
# Deliberate decoupling of the schedule length from the run length.
schedule_steps = lr_schedule_steps
else:
schedule_steps = max_train_steps
schedule = optax.schedules.warmup_cosine_decay_schedule(
init_value=0.0,
Expand Down
206 changes: 206 additions & 0 deletions tests/post_training/unit/rl_lr_schedule_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Regression tests for the RL learning-rate schedule (CPU-only).

These guard the contract that PR #4029 ("get_optimizer: respect
learning_rate_schedule_steps") was supposed to preserve but silently broke:

With a default RL config (learning_rate_schedule_steps unset == -1) the LR
schedule length must equal the actual training length (max_train_steps), so
warmup completes inside the run.

Why the bug shipped (and why the existing get_optimizer tests miss it):
`tests/post_training/unit/rl_utils_test.py::TestGetOptimizer` builds the config
from a `SimpleNamespace` that does not even carry `learning_rate_schedule_steps`,
so `getattr(cfg, "learning_rate_schedule_steps", -1)` returns -1, the
`schedule_steps <= 0` fallback fires, and the schedule correctly tracks
max_train_steps. The Pydantic validator never runs.

In a real run it DOES run: `MaxTextConfig.set_derived_and_validate_values`
rewrites `learning_rate_schedule_steps == -1 -> steps` (base.yml default 150_001)
*before* get_optimizer is called (configs/types.py). The `<= 0` guard in
get_optimizer therefore can never fire, warmup becomes
`0.1 * 150_001 ~= 15_000` steps instead of `0.1 * max_train_steps`, and on a
500-step run the LR is ~300x too low at the same step.

The only way to catch this is to build the config through the REAL pyconfig
path so the validator runs. That is what these tests do.
"""

import sys
import unittest

import pytest

from maxtext.configs import pyconfig
from maxtext.trainers.post_train.rl import utils_rl
from tests.utils.test_helpers import get_test_config_path

pytestmark = [pytest.mark.post_training]


# Baseline overrides for a tiny model, chosen so initialize_pydantic can
# validate a MaxTextConfig quickly on CPU with no network access (same approach
# as tests/post_training/unit/lora_utils_test.py). The model shape itself does
# not matter for these tests: get_optimizer only reads scalar config fields.
_MODEL_OVERRIDES = dict(
    per_device_batch_size=1.0,
    enable_checkpointing=False,
    base_num_decoder_layers=1,
    attention="dot_product",
    max_target_length=8,
    base_emb_dim=128,
    base_num_query_heads=2,
    base_num_kv_heads=2,
    base_mlp_dim=256,
    max_prefill_predict_length=4,
    model_name="llama2-7b",
    enable_nnx=True,
    pure_nnx_decoder=True,
    override_model_config=True,
    weight_dtype="bfloat16",
)


def _make_config(**overrides):
  """Build a real MaxTextConfig for RL through `pyconfig.initialize_pydantic`.

  Going through initialize_pydantic (not a SimpleNamespace) is the whole point
  of this helper: that path runs the Pydantic validator, which is what promotes
  learning_rate_schedule_steps == -1 to `steps`.

  Args:
    **overrides: Config fields to set for the test. They take precedence over
      the default `run_name` and over the tiny-model baseline, so a test can
      replace any baseline value without a duplicate-keyword TypeError.

  Returns:
    The object returned by `pyconfig.initialize_pydantic` for the merged
    settings.
  """
  # Merge before the call: spelling this as f(**baseline, **overrides) raises
  # TypeError as soon as both mappings share a key. Later entries win here.
  settings = {"run_name": "rl_lr_schedule_test", **_MODEL_OVERRIDES, **overrides}
  argv = [sys.argv[0], get_test_config_path()]
  return pyconfig.initialize_pydantic(argv, **settings)


def _max_train_steps(config):
  """Compute the RL run length (in train steps) from `config`.

  This is a local copy of train_rl.get_max_train_steps. It is duplicated rather
  than imported because importing train_rl pulls in the heavy RL training stack
  (tunix/vLLM). The formula is part of the contract these tests check, so if it
  drifts from train_rl that is a meaningful signal.

  Args:
    config: Object exposing `num_batches`, `rl.num_iterations`,
      `train_fraction` and `num_epoch`.

  Returns:
    The product of those four fields, truncated to an int.
  """
  # Multiply strictly left to right so float rounding matches the one-line form.
  batch_iterations = config.num_batches * config.rl.num_iterations
  trained_portion = batch_iterations * config.train_fraction
  total_steps = trained_portion * config.num_epoch
  return int(total_steps)


def _effective_lr_at_step(opt, step):
  """Return the learning rate `opt` reports after `step` updates.

  The optimizer is initialised on a single scalar parameter and then updated
  `step` times with zero gradients; the value is read from
  `opt_state.hyperparams["learning_rate"]`. `opt` is expected to be an
  inject_hyperparams optimizer. That field is the per-step LR tunix's
  peft_trainer reads and logs, so the result reflects what training actually
  sees rather than a re-derivation of the schedule.

  Args:
    opt: Optimizer exposing `init(params)` and `update(grads, state, params)`.
    step: Number of update calls to apply before reading the LR.

  Returns:
    The learning rate as a Python float.
  """
  import jax.numpy as jnp  # pylint: disable=import-outside-toplevel

  def _scalar_zero_tree():
    # One float32 scalar leaf is enough: only the optimizer state matters here.
    return {"w": jnp.zeros((), dtype=jnp.float32)}

  params = _scalar_zero_tree()
  grads = _scalar_zero_tree()
  opt_state = opt.init(params)
  for _ in range(step):
    _unused_updates, opt_state = opt.update(grads, opt_state, params)
  learning_rate = opt_state.hyperparams["learning_rate"]
  return float(learning_rate)


class RLLearningRateScheduleTest(unittest.TestCase):
  """Checks the schedule shape of utils_rl.get_optimizer on a real config."""

  @staticmethod
  def _schedule_overrides(peak, num_batches, **extra):
    """Assemble the config overrides shared by the schedule-shape tests.

    Keys are inserted in a fixed order (common fields first, then `extra`) so
    the keyword order seen by _make_config is stable.
    """
    overrides = {
        "learning_rate": peak,
        "warmup_steps_fraction": 0.1,
        "gradient_clipping_threshold": 0.0,
        "num_batches": num_batches,
        "num_epoch": 1,
        "train_fraction": 1.0,
    }
    overrides.update(extra)
    return overrides

  @pytest.mark.cpu_only
  def test_default_rl_config_warms_up_within_run(self):
    """REGRESSION GUARD (FAILS on PR #4029, passes once fixed).

    learning_rate_schedule_steps is left at -1 and `steps` at base.yml's
    default (150_001). The LR must nevertheless reach its configured peak by
    the end of the intended warmup (0.1 * max_train_steps). The buggy code
    sizes the warmup to 150_001, which keeps the LR near zero for the whole
    run.
    """
    peak = 3e-6
    config = _make_config(
        **self._schedule_overrides(
            peak,
            500,
            steps=150_001,  # base.yml default (a pretraining-sized number)
            learning_rate_schedule_steps=-1,  # left unset, as in base.yml
        )
    )
    max_train_steps = _max_train_steps(config)
    # Precondition: the bug only shows when the run is far shorter than `steps`.
    self.assertLess(max_train_steps, config.steps)

    opt = utils_rl.get_optimizer(config, max_train_steps)
    warmup_end = int(config.warmup_steps_fraction * max_train_steps)
    lr = _effective_lr_at_step(opt, warmup_end + 5)
    failure_msg = (
        f"LR reached only {lr:.2e} by step {warmup_end + 5} (peak={peak:.2e}). "
        "Warmup was sized to base.yml `steps`, not max_train_steps "
        f"(={max_train_steps}). learning_rate_schedule_steps in effect="
        f"{config.learning_rate_schedule_steps}."
    )
    # Fixed code gives lr ~= peak; the #4029 behaviour gives roughly
    # peak * warmup_end / 15000 (< 1% of peak). Half of peak separates the two.
    self.assertGreaterEqual(lr, 0.5 * peak, msg=failure_msg)

  @pytest.mark.cpu_only
  def test_explicit_schedule_steps_decouples_from_run_length(self):
    """FEATURE GUARD (passes before and after the fix).

    When learning_rate_schedule_steps is set explicitly it must determine the
    schedule shape regardless of max_train_steps (the capability #4029 added).
    This keeps the regression fix from simply removing that feature.
    """
    peak_lr = 3e-6
    explicit_len = 1000
    config = _make_config(
        **self._schedule_overrides(
            peak_lr,
            50,
            learning_rate_schedule_steps=explicit_len,  # explicitly set by the user
        )
    )
    run_len = _max_train_steps(config)

    opt = utils_rl.get_optimizer(config, run_len)
    # Warmup length must come from the explicit schedule (0.1 * 1000 = 100
    # steps), not from the run length. Fixed probe steps keep the assertions
    # independent of num_iterations: at step 20 the LR is still small, and just
    # past step 100 it is at peak. A fix that forced the schedule to the (short)
    # run length would already be near peak at step 20 and fail the first check.
    early_lr = _effective_lr_at_step(opt, 20)
    self.assertLessEqual(early_lr, 0.4 * peak_lr)
    post_warmup_step = int(0.1 * explicit_len) + 5
    self.assertGreaterEqual(_effective_lr_at_step(opt, post_warmup_step), 0.9 * peak_lr)

  @pytest.mark.cpu_only
  def test_validator_overwrites_minus_one_sentinel(self):
    """ROOT-CAUSE characterization (effective-value assertion).

    get_optimizer never sees the -1 the user left in place, because the
    validator has already replaced it with `steps`. That is why a `<= 0`
    fallback inside get_optimizer is dead code in a real run. (Passes today;
    documents the seam behind test_default_rl_config_warms_up_within_run.)
    """
    config = _make_config(steps=150_001, learning_rate_schedule_steps=-1)
    effective = config.learning_rate_schedule_steps
    self.assertNotEqual(effective, -1)
    self.assertEqual(effective, config.steps)


if __name__ == "__main__":
unittest.main()
Loading