From dfc78785500ceac9c7682b0fdb43141acf34d325 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 23 May 2026 03:19:48 +0000 Subject: [PATCH] transformerless_lm: continuous self-distillation + cycle checkpointing After PR #4 closed the train/inference omniweight asymmetry, the natural follow-up is to not stop at cycle 6. The active_base ratchet (seed + appended best refined outputs) is exactly the kind of process where compounding past a fixed budget might find regimes the 6-cycle window can't reach. --continuous: replaces `for cycle in range(n_cycles)` with an effectively unbounded loop (the cycle limit becomes 10**9). n_cycles still controls steps_per_cycle (args.steps // n_cycles) so the per-cycle training budget stays calibrated; the cycle counter just keeps going. The K-shrink schedule clamps to K_min once global_step exceeds args.steps, which is the standard end state of the curriculum anyway. --checkpoint PATH: serializes the entire distillation state every cycle (model state_dict, FibAdamW optimizer state, active_base, cycle counter, global_step, best_creativity, best_val/step, cycle_summary, rejection counters, best_refined_seq). The write is atomic (tmp file + os.replace), so an interrupt mid-save can't corrupt the file. If the checkpoint exists at startup, training resumes from the saved cycle+1 with the active_base fully intact -- the ratchet picks up exactly where it stopped. Default behavior unchanged: omitting both flags reproduces the v88 + omniweight-loss bounded 6-cycle run. Run a forever-distillation with omniweight-loss: python3 train_self_recursive.py --omniweight-loss --continuous --checkpoint omniweight_distill.pt Resume after Ctrl-C: re-run the same command. The checkpoint state is restored and training continues at start_cycle (the saved cycle + 1). 
--- .../train_self_recursive.py | 74 ++++++++++++++++++- 1 file changed, 70 insertions(+), 4 deletions(-) diff --git a/experiments/transformerless_lm/train_self_recursive.py b/experiments/transformerless_lm/train_self_recursive.py index ade3212..650c4f1 100644 --- a/experiments/transformerless_lm/train_self_recursive.py +++ b/experiments/transformerless_lm/train_self_recursive.py @@ -26,6 +26,7 @@ import argparse import json +import os import sys import time import math @@ -2504,7 +2505,9 @@ def train_with_self_distillation(name, train_seed, corpus_anchor, val_split, distill_prob: float = 0.3, samples_per_cycle: int = 8, keep_top_k: int = 4, - growth_n_new: int = 128): + growth_n_new: int = 128, + continuous: bool = False, + checkpoint_path: str = None): """Self-distillation: model's high-creativity refined outputs become training targets for the next cycle. @@ -2670,12 +2673,43 @@ def quality_fn(seq_tokens): cur_K = None eval_every = max(steps_per_cycle // 4, 100) global_step = 0 + start_cycle = 0 prompt = train_seed[:16].unsqueeze(0) + # Resume from checkpoint if one exists at checkpoint_path. Holds + # the entire distillation state (model, optimizer, active_base, + # cycle counter, best_creativity). The ratchet picks up exactly + # where it stopped. 
+ if checkpoint_path and os.path.exists(checkpoint_path): + print(f" resuming from checkpoint: {checkpoint_path}", flush=True) + ckpt = torch.load(checkpoint_path, map_location='cpu', + weights_only=False) + model.load_state_dict(ckpt['model']) + optimizer.load_state_dict(ckpt['optimizer']) + active_base = ckpt['active_base'] + best_creativity = float(ckpt['best_creativity']) + cycle_summary = ckpt.get('cycle_summary', []) + global_step = int(ckpt.get('global_step', 0)) + start_cycle = int(ckpt.get('cycle', -1)) + 1 + n_rejected_below_baseline = int(ckpt.get('n_rejected_below_baseline', 0)) + n_rejected_anchor = int(ckpt.get('n_rejected_anchor', 0)) + best_val = float(ckpt.get('best_val', float('inf'))) + best_step = int(ckpt.get('best_step', -1)) + if ckpt.get('best_refined_seq') is not None: + best_refined_seq = ckpt['best_refined_seq'] + print(f" resumed at cycle {start_cycle}, " + f"active_base={active_base.numel()} tokens, " + f"best_creativity={best_creativity:.4f}, " + f"global_step={global_step}", flush=True) + # Vocab curriculum disabled for v59 -- v58 showed it hurts mid-cycles. - for cycle in range(n_cycles): + # Continuous mode: loop forever (until interrupted), checkpointing + # every cycle. Otherwise bounded by n_cycles. + cycle_limit = (10 ** 9) if continuous else n_cycles + for cycle in range(start_cycle, cycle_limit): active_vocab_size = None - print(f"\n --- Cycle {cycle+1}/{n_cycles} " + total_label = '∞' if continuous else str(n_cycles) + print(f"\n --- Cycle {cycle+1}/{total_label} " f"active_base_size={active_base.numel()} chars " f"best_creativity={best_creativity:.4f} ---", flush=True) for s in range(steps_per_cycle): @@ -2798,6 +2832,29 @@ def quality_fn(seq_tokens): print(f" best sample (c={kept[0][1]:.3f}):\n " f"{repr(sample_text[:200])}") + # Checkpoint after every cycle. Atomic write via temp file so + # an interrupt mid-save doesn't corrupt the checkpoint. 
+ if checkpoint_path: + tmp = checkpoint_path + ".tmp" + torch.save({ + 'model': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'active_base': active_base, + 'cycle': cycle, + 'global_step': global_step, + 'best_creativity': best_creativity, + 'best_val': best_val, + 'best_step': best_step, + 'cycle_summary': cycle_summary, + 'n_rejected_below_baseline': n_rejected_below_baseline, + 'n_rejected_anchor': n_rejected_anchor, + 'best_refined_seq': best_refined_seq, + }, tmp) + os.replace(tmp, checkpoint_path) + print(f" checkpoint saved: {checkpoint_path} " + f"(cycle {cycle+1}, active_base={active_base.numel()} tokens)", + flush=True) + # Final generation for inspection. final_gen = autoregressive_generate(model, prompt, n_new=n_new, vocab_size=vocab_size, @@ -3254,6 +3311,14 @@ def main(): "(phi^pi tanh fluid form) to per-token CE " "during training. Closes the train/inference " "asymmetry on the anti-stagnation primitive.") + parser.add_argument("--continuous", action="store_true", + help="Self-distill indefinitely instead of stopping " + "at n_cycles. Checkpoints every cycle so the " + "run is Ctrl-C-resumable.") + parser.add_argument("--checkpoint", type=str, default=None, + help="Path to checkpoint file. If it exists, " + "training resumes from it. Saved after every " + "cycle. Required for --continuous.") parser.add_argument("--out", type=str, default="results_self_recursive.json") args = parser.parse_args() @@ -3308,7 +3373,8 @@ def main(): itos_map=itos_map, corpus_text=full_corpus_text, vocab_for_bigram=sub_tok.vocab, n_cycles=6, distill_prob=0.3, - samples_per_cycle=8, keep_top_k=4, growth_n_new=128) + samples_per_cycle=8, keep_top_k=4, growth_n_new=128, + continuous=args.continuous, checkpoint_path=args.checkpoint) print() print("=" * 92)