Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 70 additions & 4 deletions experiments/transformerless_lm/train_self_recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import argparse
import json
import os
import sys
import time
import math
Expand Down Expand Up @@ -2504,7 +2505,9 @@ def train_with_self_distillation(name, train_seed, corpus_anchor, val_split,
distill_prob: float = 0.3,
samples_per_cycle: int = 8,
keep_top_k: int = 4,
growth_n_new: int = 128):
growth_n_new: int = 128,
continuous: bool = False,
checkpoint_path: str = None):
"""Self-distillation: model's high-creativity refined outputs become
training targets for the next cycle.

Expand Down Expand Up @@ -2670,12 +2673,43 @@ def quality_fn(seq_tokens):
cur_K = None
eval_every = max(steps_per_cycle // 4, 100)
global_step = 0
start_cycle = 0
prompt = train_seed[:16].unsqueeze(0)

# Resume from checkpoint if one exists at checkpoint_path. The
# checkpoint holds the entire distillation state (model, optimizer,
# active_base, cycle counter, best_creativity), so the ratchet
# picks up exactly where it stopped.
if checkpoint_path and os.path.exists(checkpoint_path):
print(f" resuming from checkpoint: {checkpoint_path}", flush=True)
ckpt = torch.load(checkpoint_path, map_location='cpu',
weights_only=False)
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
active_base = ckpt['active_base']
best_creativity = float(ckpt['best_creativity'])
cycle_summary = ckpt.get('cycle_summary', [])
global_step = int(ckpt.get('global_step', 0))
start_cycle = int(ckpt.get('cycle', -1)) + 1
n_rejected_below_baseline = int(ckpt.get('n_rejected_below_baseline', 0))
n_rejected_anchor = int(ckpt.get('n_rejected_anchor', 0))
best_val = float(ckpt.get('best_val', float('inf')))
best_step = int(ckpt.get('best_step', -1))
if ckpt.get('best_refined_seq') is not None:
best_refined_seq = ckpt['best_refined_seq']
print(f" resumed at cycle {start_cycle}, "
f"active_base={active_base.numel()} tokens, "
f"best_creativity={best_creativity:.4f}, "
f"global_step={global_step}", flush=True)

# Vocab curriculum disabled for v59 -- v58 showed it hurts mid-cycles.
for cycle in range(n_cycles):
# Continuous mode: loop forever (until interrupted), checkpointing
# every cycle. Otherwise bounded by n_cycles.
cycle_limit = (10 ** 9) if continuous else n_cycles
for cycle in range(start_cycle, cycle_limit):
active_vocab_size = None
print(f"\n --- Cycle {cycle+1}/{n_cycles} "
total_label = '∞' if continuous else str(n_cycles)
print(f"\n --- Cycle {cycle+1}/{total_label} "
f"active_base_size={active_base.numel()} chars "
f"best_creativity={best_creativity:.4f} ---", flush=True)
for s in range(steps_per_cycle):
Expand Down Expand Up @@ -2798,6 +2832,29 @@ def quality_fn(seq_tokens):
print(f" best sample (c={kept[0][1]:.3f}):\n "
f"{repr(sample_text[:200])}")

# Checkpoint after every cycle. Atomic write via temp file so
# an interrupt mid-save doesn't corrupt the checkpoint.
if checkpoint_path:
tmp = checkpoint_path + ".tmp"
torch.save({
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'active_base': active_base,
'cycle': cycle,
'global_step': global_step,
'best_creativity': best_creativity,
'best_val': best_val,
'best_step': best_step,
'cycle_summary': cycle_summary,
'n_rejected_below_baseline': n_rejected_below_baseline,
'n_rejected_anchor': n_rejected_anchor,
'best_refined_seq': best_refined_seq,
}, tmp)
os.replace(tmp, checkpoint_path)
print(f" checkpoint saved: {checkpoint_path} "
f"(cycle {cycle+1}, active_base={active_base.numel()} tokens)",
flush=True)

# Final generation for inspection.
final_gen = autoregressive_generate(model, prompt, n_new=n_new,
vocab_size=vocab_size,
Expand Down Expand Up @@ -3254,6 +3311,14 @@ def main():
"(phi^pi tanh fluid form) to per-token CE "
"during training. Closes the train/inference "
"asymmetry on the anti-stagnation primitive.")
parser.add_argument("--continuous", action="store_true",
help="Self-distill indefinitely instead of stopping "
"at n_cycles. Checkpoints every cycle so the "
"run is Ctrl-C-resumable.")
parser.add_argument("--checkpoint", type=str, default=None,
help="Path to checkpoint file. If it exists, "
"training resumes from it. Saved after every "
"cycle. Required for --continuous.")
parser.add_argument("--out", type=str,
default="results_self_recursive.json")
args = parser.parse_args()
Expand Down Expand Up @@ -3308,7 +3373,8 @@ def main():
itos_map=itos_map, corpus_text=full_corpus_text,
vocab_for_bigram=sub_tok.vocab,
n_cycles=6, distill_prob=0.3,
samples_per_cycle=8, keep_top_k=4, growth_n_new=128)
samples_per_cycle=8, keep_top_k=4, growth_n_new=128,
continuous=args.continuous, checkpoint_path=args.checkpoint)

print()
print("=" * 92)
Expand Down
Loading