diff --git a/experiments/transformerless_lm/train_self_recursive.py b/experiments/transformerless_lm/train_self_recursive.py index ade3212..650c4f1 100644 --- a/experiments/transformerless_lm/train_self_recursive.py +++ b/experiments/transformerless_lm/train_self_recursive.py @@ -26,6 +26,7 @@ import argparse import json +import os import sys import time import math @@ -2504,7 +2505,9 @@ def train_with_self_distillation(name, train_seed, corpus_anchor, val_split, distill_prob: float = 0.3, samples_per_cycle: int = 8, keep_top_k: int = 4, - growth_n_new: int = 128): + growth_n_new: int = 128, + continuous: bool = False, + checkpoint_path: str = None): """Self-distillation: model's high-creativity refined outputs become training targets for the next cycle. @@ -2670,12 +2673,43 @@ def quality_fn(seq_tokens): cur_K = None eval_every = max(steps_per_cycle // 4, 100) global_step = 0 + start_cycle = 0 prompt = train_seed[:16].unsqueeze(0) + # Resume from checkpoint if one exists at checkpoint_path. Holds + # the entire distillation state (model, optimizer, active_base, + # cycle counter, best_creativity). The ratchet picks up exactly + # where it stopped. + if checkpoint_path and os.path.exists(checkpoint_path): + print(f" resuming from checkpoint: {checkpoint_path}", flush=True) + ckpt = torch.load(checkpoint_path, map_location='cpu', + weights_only=False) + model.load_state_dict(ckpt['model']) + optimizer.load_state_dict(ckpt['optimizer']) + active_base = ckpt['active_base'] + best_creativity = float(ckpt['best_creativity']) + cycle_summary = ckpt.get('cycle_summary', []) + global_step = int(ckpt.get('global_step', 0)) + start_cycle = int(ckpt.get('cycle', -1)) + 1 + n_rejected_below_baseline = int(ckpt.get('n_rejected_below_baseline', 0)) + n_rejected_anchor = int(ckpt.get('n_rejected_anchor', 0)) + best_val = float(ckpt.get('best_val', float('inf'))) + best_step = int(ckpt.get('best_step', -1)) + if ckpt.get('best_refined_seq') is not None: + best_refined_seq = ckpt['best_refined_seq'] + print(f" resumed at cycle {start_cycle}, " + f"active_base={active_base.numel()} tokens, " + f"best_creativity={best_creativity:.4f}, " + f"global_step={global_step}", flush=True) + # Vocab curriculum disabled for v59 -- v58 showed it hurts mid-cycles. - for cycle in range(n_cycles): + # Continuous mode: loop forever (until interrupted), checkpointing + # every cycle. Otherwise bounded by n_cycles. + cycle_limit = (10 ** 9) if continuous else n_cycles + for cycle in range(start_cycle, cycle_limit): active_vocab_size = None - print(f"\n --- Cycle {cycle+1}/{n_cycles} " + total_label = '∞' if continuous else str(n_cycles) + print(f"\n --- Cycle {cycle+1}/{total_label} " f"active_base_size={active_base.numel()} chars " f"best_creativity={best_creativity:.4f} ---", flush=True) for s in range(steps_per_cycle): @@ -2798,6 +2832,29 @@ def quality_fn(seq_tokens): print(f" best sample (c={kept[0][1]:.3f}):\n " f"{repr(sample_text[:200])}") + # Checkpoint after every cycle. Atomic write via temp file so + # an interrupt mid-save doesn't corrupt the checkpoint. + if checkpoint_path: + tmp = checkpoint_path + ".tmp" + torch.save({ + 'model': model.state_dict(), + 'optimizer': optimizer.state_dict(), + 'active_base': active_base, + 'cycle': cycle, + 'global_step': global_step, + 'best_creativity': best_creativity, + 'best_val': best_val, + 'best_step': best_step, + 'cycle_summary': cycle_summary, + 'n_rejected_below_baseline': n_rejected_below_baseline, + 'n_rejected_anchor': n_rejected_anchor, + 'best_refined_seq': best_refined_seq, + }, tmp) + os.replace(tmp, checkpoint_path) + print(f" checkpoint saved: {checkpoint_path} " + f"(cycle {cycle+1}, active_base={active_base.numel()} tokens)", + flush=True) + # Final generation for inspection. final_gen = autoregressive_generate(model, prompt, n_new=n_new, vocab_size=vocab_size, @@ -3254,6 +3311,14 @@ def main(): "(phi^pi tanh fluid form) to per-token CE " "during training. Closes the train/inference " "asymmetry on the anti-stagnation primitive.") + parser.add_argument("--continuous", action="store_true", + help="Self-distill indefinitely instead of stopping " + "at n_cycles. Checkpoints every cycle so the " + "run is Ctrl-C-resumable.") + parser.add_argument("--checkpoint", type=str, default=None, + help="Path to checkpoint file. If it exists, " + "training resumes from it. Saved after every " + "cycle. Required for --continuous.") parser.add_argument("--out", type=str, default="results_self_recursive.json") args = parser.parse_args() @@ -3308,7 +3373,8 @@ def main(): itos_map=itos_map, corpus_text=full_corpus_text, vocab_for_bigram=sub_tok.vocab, n_cycles=6, distill_prob=0.3, - samples_per_cycle=8, keep_top_k=4, growth_n_new=128) + samples_per_cycle=8, keep_top_k=4, growth_n_new=128, + continuous=args.continuous, checkpoint_path=args.checkpoint) print() print("=" * 92)