diff --git a/README.md b/README.md index e672ea9..39a7ced 100644 --- a/README.md +++ b/README.md @@ -154,24 +154,38 @@ Full loss curves and per-epoch breakdown: [summary.md](summary.md) ### Loading from Python ```python -import torch -from src.model import UNetModel, DDIMScheduler +import sys, torch +sys.path.insert(0, "src") # make src/ importable + from huggingface_hub import hf_hub_download +from SD_Model import UNetModel # legacy single-file module +# — or, equivalently, the refactored module: from model import UNetModel checkpoint = hf_hub_download( repo_id="atandra2000/sd-from-scratch-v1", filename="sd_epoch_042.pt", + local_dir="checkpoints", ) - ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False) -unet_sd = ckpt["unet_state_dict"] +# Load EMA shadow (produces better images than live weights) unet = UNetModel(in_ch=4, out_ch=4, ch=320, res_blks=2, attn_lvls=(1, 2, 3), ch_mults=(1, 2, 4, 4), heads=8, ctx_dim=768) -unet.load_state_dict(unet_sd, strict=True) +shadow = ckpt["ema_state_dict"]["shadow"] +cleaned = {} +for k, v in shadow.items(): + for prefix in ("module.", "unet.", "_orig_mod."): + if k.startswith(prefix): + k = k[len(prefix):] + break + cleaned[k] = v +unet.load_state_dict(cleaned, strict=False) # strict=False: a few shadow keys may be absent +unet.eval() ``` +See `src/inference.py:load_ema_unet()` for the canonical loader used in production. + --- ## Training Reproduction @@ -240,18 +254,30 @@ python src/inference.py \ ### Python API ```python -from src.generate import generate_images - -images = generate_images( - prompts=["a beautiful sunset over mountains"], - checkpoint_path="checkpoints/sd_epoch_042.pt", - num_steps=50, - guidance_scale=7.5, - seed=42, +import sys +sys.path.insert(0, "src") + +import torch +from transformers import CLIPTokenizer +from generate import load_model, generate + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = load_model("checkpoints/sd_epoch_042.pt", device) +tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + +images = generate( + model = model, + tokenizer = tokenizer, + prompts = ["a beautiful sunset over mountains"], + num_steps = 50, + guidance_scale = 7.5, + seed = 42, + output_path = "output.png", ) -images[0].save("output.png") ``` +Note: `generate()` is the function in `src/generate.py`. It takes a loaded `StableDiffusionModel`, not a checkpoint path — that's what `load_model()` is for above. + See [docs/inference.md](docs/inference.md) for all options. --- diff --git a/docs/blog_post.md b/docs/blog_post.md index 4a035ca..9d88afe 100644 --- a/docs/blog_post.md +++ b/docs/blog_post.md @@ -484,10 +484,10 @@ Higher than 9.0 introduces oversaturation and CFG artefacts (waxy skin, blown hi Negative prompts replace the unconditional embedding with something like `"blurry, low quality, distorted, deformed"`. They're the fastest way to prune common generative failure modes. -**Implementation note for the repo:** I have two inference scripts. +**Implementation note for the repo:** I have two inference scripts. Both wire negative prompts end-to-end in the current release: -- `SD_ImageGen.py` (CUDA) — properly wires negative prompts through `generate(..., negative_prompts=...)`. -- `inference.py` (CUDA/MPS, Apple Silicon friendly) — currently parses `--negative` but doesn't thread it into `generate()`; the unconditional branch still uses the empty string. **If you want true negative-prompt CFG, use `SD_ImageGen.py` until I fix that.** +- `SD_ImageGen.py` (CUDA) — `generate(..., negative_prompts=...)` parameter. +- `inference.py` (CUDA/MPS, Apple Silicon friendly) — `--negative` flag is parsed in `main()` and passed to `generate()` as `negative_prompts`; both the per-sample and broadcast (single negative for the whole batch) cases work. ### 6.4 Live vs EMA — The A/B Test That Mattered Most @@ -520,14 +520,7 @@ pred_x0 = pred_x0.clamp(-1.0, 1.0) # ← THIS It looks reasonable. It is catastrophically wrong. **SD latents are not in `[-1, 1]`** — their standard deviation is ≈ 4.0. Clamping decapitates the signal. -I removed it at inference and the images snapped into focus. - -```python -# in inference.py: monkey-patch the scheduler before sampling -SD_Model.DDIMScheduler.step = _fixed_ddim_step # the no-clamp version -``` - -**Caveat:** the clamp is still present in `SD_Model.py:736` because removing it changes behaviour for any old script that imports the class directly. The inference scripts patch it out at runtime. The proper long-term fix is to delete that line. +I made the clamp **opt-in** and turned it off by default. The current `DDIMScheduler.__init__` takes a `clamp_pred_x0: bool = False` parameter; passing `True` reproduces the old training-time behaviour. Both `inference.py` and `SD_ImageGen.py` default to `False`, so the wrong behaviour is no longer reachable from the supported code paths. **Lesson:** Never assume latent distributions match pixel distributions. @@ -571,8 +564,6 @@ If I started over tomorrow: 2. **Start with a lower Min-SNR γ (2.0–2.5) carefully.** For aesthetic-heavy data, lower γ can speed convergence — but test it visually, not just by loss. 3. **Test inference every single epoch.** I'd have caught the latent clamp bug in week one instead of week six. 4. **Validate with EMA from the very first epoch.** Don't ship a fix you forgot to backfill into your visualisations. -5. **Delete the training-time `pred_x0.clamp`** instead of monkey-patching around it. Make the wrong version unreachable. -6. **Wire `--negative` through `inference.py:generate()` properly** so both inference scripts behave identically. --- @@ -591,25 +582,61 @@ If I started over tomorrow: ``` StableDiffusion/ -├── SD_Model.py # UNet + VAE/CLIP wrappers + DDPM/DDIM schedulers -├── SD_Train.py # 2× RTX 5090 DDP + BF16 training loop -├── SD_ImageGen.py # CUDA inference (full negative-prompt CFG) -├── inference.py # CUDA/MPS inference (DDIM clamp monkey-patched) -├── encode_latents.py # 4-stage VAE → fp16 .npy pipeline -├── 01_download_metadata.py # LAION parquet snapshot -├── 01b_download_diffusiondb.py # DiffusionDB → 512×512 tar shards -├── 01c_download_journeydb_images.py # JourneyDB → 512×512 tar shards -├── 02_filter_metadata.py # aesthetic / CLIP / watermark / NSFW / dedup -├── 03_download_images.py # img2dataset LAION downloader -├── 03_build_hf_dataset.py # DiffusionDB/JourneyDB → Arrow HF dataset -├── 04_preprocess_to_cache.py # Tars → parquet (image_key + CLIP tokens) -├── 05_build_hf_dataset.py # Parquet → Arrow HF dataset -├── sd_epoch_042.pt # Released checkpoint (~12.5 GB) -├── sd-val-imgs/ # val_epoch_001..043.png (live-weight grids) -├── sd-logs/ # captured training.log / output*.log -└── generated_images/ # curated epoch-42 renders +├── src/ # Core implementation +│ ├── model.py # UNet + DDPM/DDIM schedulers (refactored) +│ ├── SD_Model.py # Legacy single-file module (kept for reproducibility) +│ ├── SD_Model_v2.py # Earlier refactor (experimental) +│ ├── SD_Model_scratch.py # Throwaway early prototype +│ ├── train.py # 2× RTX 5090 DDP + BF16 training loop (refactored) +│ ├── SD_Train.py # Training loop (the one that produced the checkpoint) +│ ├── SD_Train_v2.py # Earlier refactor +│ ├── inference.py # CUDA/MPS inference (negative prompts wired) +│ ├── SD_ImageGen.py # CUDA inference (negative prompts wired) +│ ├── generate.py # Programmatic generation API +│ ├── encode_latents.py # VAE pre-encoding → .npy +│ └── encode_pipeline.py # Data-parallel latent encoder (2-GPU) +├── data_pipeline/ # LAION-2B + DiffusionDB + JourneyDB pipeline +│ ├── 01_download_metadata.py # LAION parquet snapshot +│ ├── 01b_download_diffusiondb.py # DiffusionDB → 512×512 tar shards +│ ├── 01c_download_journeydb_images.py # JourneyDB → 512×512 tar shards +│ ├── 02_filter_metadata.py # aesthetic / CLIP / watermark / NSFW / dedup +│ ├── 03_download_images.py # img2dataset LAION downloader +│ ├── 03_build_hf_dataset.py # DiffusionDB/JourneyDB → Arrow HF dataset +│ ├── 04_preprocess_to_cache.py # Tars → parquet (image_key + CLIP tokens) +│ ├── 05_build_hf_dataset.py # Parquet → Arrow HF dataset +│ └── 06_filter_dataset.py # Final aesthetic / dedup pass +├── configs/ +│ └── config.py # Dataclass-based configuration +├── docs/ # Architecture, training, data-pipeline, inference guides +│ ├── architecture.md +│ ├── training-loop.md +│ ├── data-pipeline.md +│ ├── inference.md +│ ├── blog_post.md # (this file) +│ └── images/ # Hero collage, sample gallery, architecture diagram +├── tests/ # CPU smoke tests +│ ├── test_unet_forward.py +│ └── test_ddim_step.py +├── scripts/ +│ └── download_checkpoint.py # Hugging Face Hub checkpoint downloader +├── results/ # Training artifacts +│ ├── samples/ # Curated epoch-42 renders +│ ├── loss_curve.csv +│ └── training_status.md +├── assets/ # Architecture diagram, plots +├── sd_epoch_042.pt # Released checkpoint (hosted on HF Hub, not in repo) +├── LICENSE # MIT +├── CITATION.cff +├── README.md +├── requirements.txt +├── .github/workflows/ci.yml # Lint + smoke-test on every push +└── .env.example ``` +The released checkpoint lives at +[`atandra2000/sd-from-scratch-v1`](https://huggingface.co/atandra2000/sd-from-scratch-v1) on +the Hugging Face Hub — download it with `python scripts/download_checkpoint.py`. + --- ## Final Thoughts @@ -621,7 +648,7 @@ Training a Stable Diffusion model from scratch was one of the most rewarding eng - knowing which optimisations compose and which ones explode when you stack them, - and trusting visual validation over a loss number that lies to you. -If you're considering this: do it. But go in with your eyes open. The transformers will be fine. It's the JPEG decoder, the pinned buffer, the EMA decay and the one clamp at line 736 that will decide whether your model converges. +If you're considering this: do it. But go in with your eyes open. The transformers will be fine. It's the JPEG decoder, the pinned buffer, the EMA decay and the latent-space distribution mismatch that will decide whether your model converges. ---