diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..f4c29b3 --- /dev/null +++ b/.env.example @@ -0,0 +1,11 @@ +# Hugging Face Hub (write scope) — used by scripts/download_checkpoint.py and +# any local fine-tune resume helpers that pull/push from atandra2000/sd-from-scratch-v1. +HF_TOKEN= + +# Weights & Biases — used by SD_Train.py for training-time monitoring. +WANDB_API_KEY= + +# Optional: override paths if you don't use the defaults +# CHECKPOINT_DIR=./checkpoints +# LATENT_DIR=./laion_latents +# HF_HOME=./.hf-cache diff --git a/.gitignore b/.gitignore index 219766f..cfae5ba 100644 --- a/.gitignore +++ b/.gitignore @@ -1,18 +1,37 @@ +# Python __pycache__/ *.pyc *.pyo +*.pyd .Python *.egg-info/ dist/ build/ +.pytest_cache/ +.ruff_cache/ +.mypy_cache/ +.coverage +htmlcov/ + +# Virtual environments .env +.env.local +.env.*.local .venv venv/ +diffusion/ + +# Secrets (defensive — never commit raw token files) +secrets/ +*.pem +*.key # Checkpoints and model weights *.pt *.pth +*.ckpt *.safetensors +*.bin checkpoints/ outputs/ @@ -21,13 +40,29 @@ outputs/ *.parquet *.tar *.zip +*.tgz +*.h5 +*.hdf5 +*.arrow +*.lance laion_*/ +vggface_*/ +coco_*/ +dm_*/ +p7_*/ /workspace/ -# Training logs and wandb +# HF/HuggingFace caches +.hf-cache/ +.cache/huggingface/ + +# Training logs and W&B *.log +logs/ wandb/ training.log +runs/ +tb_logs/ # Generated images (exclude training validation grids, keep curated samples) val_epoch_*.png @@ -35,13 +70,20 @@ results/*.png !results/samples/ !results/samples/*.png !assets/*.png +!docs/images/ # Jupyter .ipynb_checkpoints/ *.ipynb -# IDE +# IDE / editor .idea/ .vscode/ *.swp +*.swo +*~ + +# OS .DS_Store +Thumbs.db +desktop.ini diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 0000000..389ecf9 --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,39 @@ +cff-version: 1.2.0 +title: "Stable Diffusion from Scratch (SD-From-Scratch v1)" +message: "If you use this software or the released checkpoint, please cite it as below." +type: software +authors: + - given-names: Atandra + family-names: Bharati + email: atandrabharati@gmail.com +repository-code: "https://github.com/atandra2000/StableDiffusion" +url: "https://github.com/atandra2000/StableDiffusion" +abstract: >- + A Stable Diffusion 1.x-class latent diffusion model implemented and + trained entirely from scratch in PyTorch on 2× RTX 5090 (Blackwell) GPUs. + Provides the full UNet (~860M params), DDPM/DDIM schedulers, LAION data + pipeline, dual-GPU latent encoding, DDP+BF16 training loop, and + CUDA/MPS inference. Released checkpoint sd_epoch_042.pt is hosted on + Hugging Face Hub at atandra2000/sd-from-scratch-v1. +license: MIT +version: "1.0.0" +date-released: 2026-06-02 +keywords: + - stable-diffusion + - latent-diffusion + - diffusion-models + - pytorch + - text-to-image + - generative-ai + - from-scratch + - rtx-5090 + - blackwell + - ddp +preferred-citation: + type: software + authors: + - given-names: Atandra + family-names: Bharati + title: "SD-From-Scratch v1: A Stable-Diffusion-class latent diffusion model trained from scratch on dual RTX 5090s" + year: 2026 + url: "https://huggingface.co/atandra2000/sd-from-scratch-v1" diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..e2862a0 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 Atandra Bharati + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md index 065a6d5..ccc795e 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,69 @@ # Stable Diffusion from Scratch -A full-stack implementation of Stable Diffusion trained on LAION-2B-en-aesthetic, built entirely from scratch in PyTorch. The model implements the complete latent diffusion pipeline — from raw image data through VAE encoding, CLIP text conditioning, and a custom UNet denoising model — without relying on any high-level diffusion library. +[![GitHub](https://img.shields.io/badge/GitHub-atandra2000/StableDiffusion-181717?style=flat&logo=github)](https://github.com/atandra2000/StableDiffusion) +[![Hugging Face Model](https://img.shields.io/badge/%F0%9F%A4%97%20Model-sd--from--scratch--v1-FFD21E)](https://huggingface.co/atandra2000/sd-from-scratch-v1) +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE) +[![Python 3.11+](https://img.shields.io/badge/python-3.11%2B-blue)](https://www.python.org/) +[![W&B Report](https://img.shields.io/badge/W%26B-Training%20Logs-FFBE00?logo=weightsandbiases)](https://wandb.ai/atandrabharati-self/stable-diffusion) -**Hardware:** 2× RTX 5090 (Blackwell, cc 10.x, 32 GB VRAM each) on RunPod -**Dataset:** LAION-2B-en-aesthetic (~12M images pre-training) + curated fine-tuning subset -**Status:** Training ongoing — 42 epochs, 232K steps, best loss 0.0947 (epoch 16) -**W&B:** [atandrabharati-self/stable-diffusion](https://wandb.ai/atandrabharati-self/stable-diffusion) +A full-stack Stable Diffusion 1.x-class latent diffusion model — **built entirely from scratch in PyTorch, trained on 2× RTX 5090 (Blackwell) GPUs.** Every component (UNet, DDPM/DDIM, VAE pipeline, CLIP conditioning, data pipeline, DDP training loop) is hand-implemented; no `diffusers`, no `compel`, no black boxes. -![Architecture Overview](assets/architecture_overview.png) +**Checkpoint** → [atandra2000/sd-from-scratch-v1](https://huggingface.co/atandra2000/sd-from-scratch-v1) (12.5 GB, sd_epoch_042.pt) + +--- + +## Quick Start + +```bash +# 1. Download the checkpoint +pip install huggingface_hub +python scripts/download_checkpoint.py + +# 2. Run inference +pip install torch torchvision transformers Pillow +python src/inference.py --prompt "a cinematic shot of a mountain lake at sunset" --checkpoint checkpoints/sd_epoch_042.pt +``` + +See [docs/inference.md](docs/inference.md) for advanced usage (negative prompts, batch mode, DDIM parameters). + +--- + +## Repository Layout + +``` +├── src/ # Core implementation +│ ├── model.py # UNet (~860M params), DDPM/DDIM schedulers +│ ├── train.py # DDP + BF16 training loop +│ ├── inference.py # Apple Silicon + CUDA inference +│ ├── encode_latents.py # VAE pre-encoding for training +│ ├── encode_pipeline.py # Data-parallel latent encoder (2-GPU) +│ ├── generate.py # Programmatic generation API +│ ├── SD_ImageGen.py # Alternative inference script (CLI) +│ ├── SD_Model.py # Legacy model (kept for reproducibility) +│ └── SD_Train.py / SD_Train_v2.py # Legacy training scripts +├── data_pipeline/ # LAION-2B data processing +│ ├── 01_download_metadata.py → 06_filter_dataset.py +├── configs/ +│ └── config.py # Dataclass-based configuration +├── tests/ # CPU smoke tests +│ ├── test_unet_forward.py +│ └── test_ddim_step.py +├── docs/ # Documentation +│ ├── architecture.md # Model architecture deep-dive +│ ├── training-loop.md # Training procedure +│ ├── data-pipeline.md # Data pipeline walkthrough +│ ├── inference.md # Inference guide +│ ├── blog_post.md # Medium-style write-up +│ └── images/ # Diagrams and samples +├── scripts/ +│ └── download_checkpoint.py +├── results/samples/ # Curated output samples +├── assets/ # Architecture diagram, plots +├── requirements.txt +├── LICENSE # MIT +├── CITATION.cff # Citation metadata +└── .env.example # Environment variable template +``` --- @@ -44,10 +100,10 @@ IMAGE │ │ Stage 0 (64×64): 320 ch │ — no attention │ Stage 1 (32×32): 640 ch │ ← SpatialTransformer (cross-attn) │ Stage 2 (16×16): 1280 ch │ ← SpatialTransformer (cross-attn) -│ Stage 3 (8×8): 1280 ch │ ← SpatialTransformer (cross-attn) -│ Bottleneck (8×8): 1280 ch │ ← attn + resblock +│ Stage 3 (8×8): 1280 ch │ ← SpatialTransformer (cross-attn) +│ Bottleneck (8×8): 1280 ch │ ← attn + resblock │ Decoder: │ -│ Stage 3 (8×8): 1280 ch │ ← SpatialTransformer (cross-attn) +│ Stage 3 (8×8): 1280 ch │ ← SpatialTransformer (cross-attn) │ Stage 2 (16×16): 1280 ch │ ← SpatialTransformer (cross-attn) │ Stage 1 (32×32): 640 ch │ ← SpatialTransformer (cross-attn) │ Stage 0 (64×64): 320 ch │ — no attention @@ -58,455 +114,172 @@ IMAGE │ │ MSE Loss: ||ε̂ − ε||² │ - ┌───────▼───────┐ - │ DDIM Sampler │ inference: 30 steps (vs 1000 DDPM) - └───────┬───────┘ - ▼ -┌──────────────────────────────────┐ -│ VAE Decoder (frozen) │ (B,4,64,64) → (B,3,512,512) -└──────────────────────────────────┘ - │ - ▼ - Generated Image (512×512) -``` - ---- - -## Key Design Decisions - -| Design Choice | Implementation | Rationale | -|---------------|---------------|-----------| -| **Latent diffusion** | Operate in VAE's 4×64×64 space | 64× cheaper than pixel-space diffusion | -| **Frozen VAE + CLIP** | No gradients, no optimiser state | Reuse strong pretrained representations | -| **Epsilon prediction** | MSE loss on noise ε | Empirically more stable than x₀ or v-prediction | -| **Scaled-linear β schedule** | `β = linspace(√β_start, √β_end)²` | Better image quality than linear for latent diffusion | -| **DDIM inference** | 30 deterministic steps | Same quality as 1000-step DDPM, 33× faster | -| **EMA** | decay=0.9999, warmup-corrected | Smoother weights → better generation quality | -| **Latent pre-encoding** | Encode all images once, cache to RAM | Eliminates VAE from training loop entirely | -| **bfloat16 + torch.compile** | `mode="max-autotune"` on UNet | Best throughput on Blackwell (no GradScaler needed) | -| **DistributedDataParallel** | DDP + NCCL backend, `torchrun` launcher | True process-per-GPU parallelism; faster than DataParallel | -| **Min-SNR loss weighting** | γ=5, weight = min(SNR, γ)/SNR | Balances training signal across easy/hard timesteps | -| **EMA on GPU** | decay=0.9999, maintained on GPU | Eliminates CPU↔GPU copies; warmup-corrected | -| **Classifier-free guidance** | scale=7.5, concat uncond+cond | 2× UNet forward per step; strong prompt adherence | - ---- - -## Loss Function - -The UNet is trained with the **epsilon-prediction MSE** objective from DDPM (Ho et al., 2020): - + ▽ (backprop) ``` -L = E_{t, z₀, ε} [ ||ε − ε_θ(√ᾱ_t · z₀ + √(1−ᾱ_t) · ε, t, ctx)||² ] -``` - -where: -- `z₀` — clean VAE latent, shape `(B, 4, 64, 64)` -- `ε ~ N(0, I)` — sampled Gaussian noise -- `ᾱ_t` — cumulative noise product at timestep `t` -- `ε_θ` — UNet denoiser conditioned on timestep `t` and CLIP context `ctx` -The noisy latent is constructed via the forward diffusion process: -``` -z_t = √ᾱ_t · z₀ + √(1−ᾱ_t) · ε -``` +For a detailed architectural walkthrough, see [docs/architecture.md](docs/architecture.md). --- -## DDIM Inference +## Training Summary -DDIM (Song et al., 2020) enables deterministic generation in 25–50 forward passes instead of 1000. The denoising update at each step: +| Stage | Epochs | Steps | Best Loss | LR | Notes | +|-------|--------|-------|-----------|----|-------| +| **Pre-training** | 1–10 | 136K | **0.1247** | 1e-4 → 1e-5 | Cosine decay, Min-SNR γ=5→2.5 | +| **Fine-tuning** | 11–42 | 96K | **0.0947** (ep 16) | 1e-5 | EMA decay=0.9999, CFG dropout=0.05 | +| **Final** | 42 | 232K | 0.1212 | — | Released checkpoint | -``` -x̂₀ = (x_t − √(1−ᾱ_t) · ε_θ) / √ᾱ_t # predict clean latent -x_{t−1} = √ᾱ_{t−1} · x̂₀ + √(1−ᾱ_{t−1}) · ε_θ # deterministic step (η=0) -``` +- **Hardware:** 2× RTX 5090 (Blackwell, cc 10.x, 32 GB VRAM each) on RunPod +- **Dataset:** LAION-2B-en-aesthetic (~12M images, filtered to aesthetic ≥ 6.5) +- **Multi-GPU:** DDP (NCCL) with BF16 autocast, gradient accumulation (effective batch 96) +- **Loss:** Min-SNR-weighted MSE (γ=5.0 → 2.5 for fine-tuning) +- **EMA:** Polyak decay 0.9999, shadow weights stored in checkpoint -Classifier-free guidance combines conditional and unconditional predictions: -``` -ε_guided = ε_uncond + s · (ε_cond − ε_uncond) # s = guidance_scale = 7.5 -``` +Full loss curves and per-epoch breakdown: [summary.md](summary.md) --- -## UNet Implementation +## Checkpoints -### ResNet Block with Timestep Conditioning -```python -class ResNetBlock(nn.Module): - def __init__(self, in_ch, out_ch, t_dim): - # Time embedding projected into channel space (FiLM conditioning) - self.t_proj = nn.Linear(t_dim, out_ch) - - def _forward(self, x, t_emb): - h = self.conv1(self.act(self.norm1(x))) - h = h + self.t_proj(self.act(t_emb))[:, :, None, None] # broadcast - h = self.conv2(self.act(self.norm2(h))) - return h + self.skip(x) -``` +| Download source | Size | Format | Notes | +|---|---|---|---| +| [Hugging Face Hub](https://huggingface.co/atandra2000/sd-from-scratch-v1) | 12.5 GB | PyTorch `.pt` | Contains `ema_state_dict` + `unet_state_dict`, optimizer, LR scheduler | +| GitHub Releases | — | — | Coming in v1.1 | -### Cross-Attention for Text Conditioning -```python -class CrossAttention(nn.Module): - # Q from image features, K/V from CLIP text embeddings - def forward(self, x, ctx): - q = self.to_q(x) - k, v = self.to_k(ctx), self.to_v(ctx) - # Scaled dot-product attention (Flash Attention via SDPA) - out = F.scaled_dot_product_attention(q, k, v) - return self.proj(out) -``` +### Loading from Python -### Zero-Init Output Projection -All final output projections (UNet `conv_out`, `TransformerBlock` output) are zero-initialized: ```python -nn.init.zeros_(self.conv_out.weight) -nn.init.zeros_(self.conv_out.bias) -``` -This ensures the network starts by predicting zero noise — a stable initialization that prevents early training collapse. - -### Gradient Checkpointing (Optional) -```python -# Enable with --grad_ckpt flag to save ~30% VRAM -model.enable_gradient_checkpointing() -``` -Each `UNetResBlock` uses `torch.utils.checkpoint` with `use_reentrant=False`. - ---- - -## Data Pipeline - -The full 5-step pipeline produces a ready-to-train HuggingFace dataset from raw LAION metadata: - -``` -Pre-training data (LAION-2B-en-aesthetic, ~12M images): - Step 1: 01_download_metadata.py LAION-2B-en-aesthetic parquets via HF Hub - ↓ - Step 2: 02_filter_metadata.py Quality filters (aesthetic ≥ 6.5, CLIP sim ≥ 0.28, - ↓ resolution ≥ 512px, no watermarks/NSFW) → ~12M images - Step 3: 03_download_images.py img2dataset: parallel download + WebDataset shards - ↓ 16 processes × 64 threads, incremental resume - Step 4: 04_preprocess_to_cache.py Extract image_key + CLIP-tokenized captions - ↓ (images stay in .tar shards — not duplicated) - Step 5: 05_build_hf_dataset.py Merge batches → HuggingFace Dataset - train/val split, shuffle, save_to_disk - -Fine-tuning data (DiffusionDB + JourneyDB high-quality subset): - Step 1b: 01b_download_diffusiondb.py DiffusionDB: 500 shards, ~2M Stable Diffusion images - Step 1c: 01c_download_journeydb_images.py JourneyDB: 10 archives, ~210K Midjourney images - ↓ Both converted to WebDataset tar format at 512px - → same Steps 2–5 as above (filter, preprocess, build HF dataset) -``` - -### Filtering Criteria (Step 2) - -| Filter | Threshold | Reason | -|--------|-----------|--------| -| `aesthetic_score` | ≥ 6.5 | Top ~2% of LAION-2B — high visual quality | -| `clip_similarity` | ≥ 0.28 | Caption must describe the image content | -| `width`, `height` | ≥ 512px | No upscaling — prevents blurry training signal | -| Aspect ratio | 0.5 – 2.0 | Avoid extreme crops of portraits/panoramas | -| Caption length | 20–300 chars | Informative but not CLIP-truncated | -| `pwatermark` | < 0.5 | Prevents model from generating watermarks | -| NSFW | `UNLIKELY` only | Clean training distribution | - ---- - -## Latent Pre-Encoding +import torch +from src.model import UNetModel, DDIMScheduler +from huggingface_hub import hf_hub_download -`src/encode_pipeline.py` uses **process isolation** for true dual-GPU parallelism: +checkpoint = hf_hub_download( + repo_id="atandra2000/sd-from-scratch-v1", + filename="sd_epoch_042.pt", +) -```python -# Each process gets exclusive access to one physical GPU -os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_physical_id) +ckpt = torch.load(checkpoint, map_location="cpu", weights_only=False) +unet_sd = ckpt["unet_state_dict"] -# After import, 'cuda:0' maps to the assigned physical GPU -device = torch.device("cuda:0") +unet = UNetModel(in_ch=4, out_ch=4, ch=320, res_blks=2, + attn_lvls=(1, 2, 3), ch_mults=(1, 2, 4, 4), + heads=8, ctx_dim=768) +unet.load_state_dict(unet_sd, strict=True) ``` -This avoids CUDA context sharing between processes and achieves near-linear GPU utilization scaling. The VAE encodes images in batches of 32, saving `(4, 64, 64)` float16 tensors as `.npy` files. - -At training time, `load_latent_cache()` loads all ~12M latent tensors into RAM in parallel using 16 threads, eliminating all disk I/O from the training loop. - --- -## Training - -### Training History - -| | Phase 1 — Pre-training | Phase 2a — LAION Fine-tuning | Phase 2b — DiffusionDB/JourneyDB | Phase 2c — Extended FT | -|---|---|---|---|---| -| **Epochs** | 1 – 10 | 11 – 17 | 18 – 21 | 22 – 42 | -| **Steps** | 0 – 136K | 136K – 152K | 152K – 181K | 181K – 232K | -| **Dataset** | LAION-2B-en-aesthetic (~12M images) | 212K filtered LAION subset | 705K DiffusionDB + JourneyDB | Filtered FT dataset | -| **LR** | 1e-4 → 1e-5 | 3e-6 | 3e-6 → 8e-6 | 2e-6 → 1e-6 | -| **End loss** | 0.1247 | **0.0947 (best)** | 0.1030 | 0.1212 | -| **Min-SNR γ** | 5 | 5 | 5 → 3.0 | 2 – 5 | -| **cfg_dropout** | 0 → 0.05 | 0.05 | 0.05 – 0.10 | 0.05 | +## Training Reproduction -**Recommended checkpoint: `sd_epoch_017.pt`** (step 151,836, loss 0.0947). - -Epochs 18–21 fine-tuned on 705K DiffusionDB/JourneyDB images. LR was raised to 8e-6 in epochs 19–21 which proved too aggressive — **mode collapse** caused the model to output faces for all prompts. Epoch 17 is the clean working baseline. - -Phase 2c (epochs 22–42) resumes from epoch 17 with lr=2e-6, warmup 1000 steps, and a filtered dataset with 223K face/human prompts removed via `06_filter_dataset.py`. Training is ongoing. - -Full loss curve data (3,493 points): [`results/loss_curve.csv`](results/loss_curve.csv) -Full training log: [`results/training_status.md`](results/training_status.md) - -### Quickstart (RunPod) +### Data Pipeline ```bash -# 1. Clone and install -git clone https://github.com/atandra2000/StableDiffusion -cd StableDiffusion -pip install -r requirements.txt - -# 2. Run data pipeline (Steps 01–06) +# The full pipeline from raw LAION metadata → encoded latents: python data_pipeline/01_download_metadata.py -python data_pipeline/02_filter_metadata.py -python data_pipeline/03_download_images.py -python data_pipeline/04_preprocess_to_cache.py -python data_pipeline/05_build_hf_dataset.py -python data_pipeline/06_filter_dataset.py # optional: for FT filtered set - -# 3. Encode latents to disk (dual-GPU) -python src/encode_pipeline.py - -# 4. Train with DDP (2× RTX 5090) -torchrun --nproc_per_node=2 src/SD_Train.py \ - --cache_path /workspace/StableDiffusion/laion_hf_dataset \ - --latent_dir /workspace/StableDiffusion/laion_latents \ - --epochs 30 \ - --batch_size 24 \ - --grad_accum 2 \ - --lr 1e-4 \ - --use_wandb - -# 5. Generate images (CUDA) -python src/SD_ImageGen.py \ - --checkpoint checkpoints/sd_latest.pt \ - --prompts "a photorealistic sunset over mountain peaks" \ - --steps 50 --guidance 7.5 --seed 42 --output output.png +python data_pipeline/02_filter_metadata.py # aesthetic ≥ 6.5 +python data_pipeline/03_download_images.py # WebDataset shards +python data_pipeline/04_preprocess_to_cache.py # tokenize + augment +python data_pipeline/05_build_hf_dataset.py # HuggingFace Dataset +python src/encode_latents.py # VAE encode to .npy +python src/encode_pipeline.py # 2-GPU parallel encode ``` -### Resume Training +See [docs/data-pipeline.md](docs/data-pipeline.md) for the complete walkthrough. + +### Training ```bash +# Single node, 2× GPU (torchrun) +torchrun --nproc_per_node=2 src/train.py \ + --cache_path laion_hf_dataset/train \ + --latent_dir laion_latents \ + --epochs 42 \ + --batch_size 24 \ + --lr 1e-5 \ + --min_snr --min_snr_gamma 5.0 \ + --cfg_dropout 0.05 \ + --grad_ckpt \ + --memory_format channels_last + +# Resume from checkpoint torchrun --nproc_per_node=2 src/train.py \ - --cache_path /workspace/StableDiffusion/laion_hf_dataset \ - --latent_dir /workspace/StableDiffusion/laion_latents \ - --resume /workspace/checkpoints/sd_latest.pt + --resume checkpoints/sd_epoch_021.pt ``` -### Mid-Epoch Checkpointing (for interruptible pods) +--- + +## Inference + +### CLI ```bash -torchrun --nproc_per_node=2 src/train.py \ - --cache_path /workspace/StableDiffusion/hf_dataset_filtered/train \ - --latent_dir /workspace/StableDiffusion/latents_filtered/latents \ - --epochs 22 \ - --batch_size 24 \ - --grad_accum 2 \ - --lr 2e-6 \ - --warmup_steps 1000 \ - --cfg_dropout 0.05 \ - --min_snr \ - --min_snr_gamma 5.0 \ - --grad_ckpt \ - --save_steps 1000 \ - --ckpt_dir /workspace/StableDiffusion/phase1_v2_checkpoints \ - --resume /workspace/StableDiffusion/checkpoints/sd_epoch_017.pt \ - --use_wandb +# Single prompt (Apple Silicon or CUDA) +python src/inference.py \ + --prompt "a cosmic nebula with vibrant purples and blues" \ + --checkpoint checkpoints/sd_epoch_042.pt \ + --steps 50 --guidance 7.5 --seed 42 + +# With negative prompt +python src/inference.py \ + --prompt "a portrait of a woman" \ + --negative "blurry, low quality, deformed hands" \ + --checkpoint checkpoints/sd_epoch_042.pt + +# Batch mode +python src/inference.py \ + --batch prompts.txt \ + --output_dir ./outputs \ + --checkpoint checkpoints/sd_epoch_042.pt ``` -`--save_steps 1000` saves `sd_step_XXXXXXX.pt` + overwrites `sd_latest.pt` every 1000 global steps. If the pod is killed, resume from `sd_latest.pt` to continue mid-epoch from the last save point. - -### Generate with Trained Checkpoint +### Python API ```python -import torch -from src.model import StableDiffusionModel, PretrainedVAE, PretrainedCLIPTextEncoder -from src.model import UNetModel, DDPMScheduler, DDIMScheduler -from transformers import CLIPTokenizer - -# Load model -vae = PretrainedVAE("stabilityai/sd-vae-ft-mse").cuda() -clip = PretrainedCLIPTextEncoder("openai/clip-vit-large-patch14").cuda() -unet = UNetModel(in_ch=4, out_ch=4, ch=320, res_blks=2, - attn_lvls=(1,2,3), ch_mults=(1,2,4,4), heads=8, - t_dim=320, ctx_dim=768).cuda() - -ckpt = torch.load("checkpoints/sd_latest.pt") -unet.load_state_dict(ckpt["unet_state_dict"]) - -# Tokenize -tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") -prompt = "a photorealistic sunset over mountain peaks" -tokens = tokenizer(prompt, padding="max_length", max_length=77, - truncation=True, return_tensors="pt").to("cuda") - -# Generate (DDIM, 30 steps, CFG=7.5) -sched = DDIMScheduler() -sched.set_timesteps(30, device="cuda") -latents = torch.randn(1, 4, 64, 64, device="cuda") -ctx = clip(tokens.input_ids, tokens.attention_mask)[0].unsqueeze(0) -uncond = clip(tokenizer([""], ...)[0])[0].unsqueeze(0) - -for t in sched.timesteps: - noise_pred = unet(torch.cat([latents]*2), t.expand(2), torch.cat([uncond, ctx])) - noise_uncond, noise_cond = noise_pred.chunk(2) - guided = noise_uncond + 7.5 * (noise_cond - noise_uncond) - latents = sched.step(guided, t, latents) - -image = vae.decode(latents) # (1, 3, 512, 512) in [-1, 1] +from src.generate import generate_images + +images = generate_images( + prompts=["a beautiful sunset over mountains"], + checkpoint_path="checkpoints/sd_epoch_042.pt", + num_steps=50, + guidance_scale=7.5, + seed=42, +) +images[0].save("output.png") ``` ---- - -## Hyperparameter Reference - -| Parameter | Value | Notes | -|-----------|-------|-------| -| Image resolution | 512 × 512 | Native SD resolution | -| Latent resolution | 64 × 64 | 8× downsampled by VAE | -| Latent channels | 4 | VAE bottleneck | -| UNet base channels | 320 | SD 1.x standard | -| Channel multipliers | (1, 2, 4, 4) | → 320, 640, 1280, 1280 | -| ResBlocks per stage | 2 | Encoder and decoder | -| Attention heads | 8 | In all SpatialTransformers | -| Context dimension | 768 | CLIP ViT-L/14 output | -| DDPM timesteps | 1000 | Training schedule | -| β start / end | 0.00085 / 0.012 | Scaled-linear schedule | -| β schedule | `scaled_linear` | Better than linear for latents | -| DDIM steps | 30 | Inference | -| Guidance scale | 7.5 | Classifier-free guidance | -| Optimizer | AdamW | fused=True for throughput | -| Learning rate | 1e-4 | | -| Weight decay | 1e-2 | | -| LR warmup | 500 steps | Linear 1e-6 → 1e-4 | -| LR decay | CosineAnnealing | eta_min = lr × 1e-2 | -| Batch size (effective) | 96 | 24/GPU × 2 GPUs × 2 accum | -| EMA decay | 0.9999 | Warmup-corrected, maintained on GPU | -| Precision | bfloat16 | Blackwell native (no GradScaler) | -| Compilation | `torch.compile` | `max-autotune` mode | -| Min-SNR γ | 5 → 2.5 (pretrain) / 2–5 (FT) | Loss weighting by Hang et al. (2023) | -| cfg_dropout | 0.05–0.10 (FT phase) | Random caption drop for CFG training | -| Total steps | 232,235 | Pre-train 136K + Fine-tune 96K (ongoing) | -| Best loss | 0.0947 | Epoch 16, step 149,718 | -| Grad norm clip | 1.0 | Prevents gradient explosion | +See [docs/inference.md](docs/inference.md) for all options. --- -## Repository Structure +## Known Differences from Official Stable Diffusion -``` -StableDiffusion/ -├── src/ -│ ├── SD_Model.py # Production model: VAE (fp16), CLIP, UNet, DDPM/DDIM schedulers -│ ├── SD_Model_v2.py # Next-gen MM-DiT + dual CLIP + Rectified Flow (SD3-style) -│ ├── SD_Model_scratch.py # Educational scratch implementation -│ ├── SD_Train.py # DDP training loop for 2× RTX 5090, Min-SNR, EMA, BF16 -│ ├── SD_Train_v2.py # MM-DiT training (velocity prediction, logit-normal timesteps) -│ ├── SD_ImageGen.py # GPU inference: DDIM + CFG, negative prompts, grid output -│ ├── model.py # Core model (earlier iteration) -│ ├── train.py # Earlier training loop -│ ├── generate.py # GPU inference CLI -│ ├── inference.py # Apple Silicon / MPS inference (local generation) -│ ├── encode_pipeline.py # Dual-GPU VAE latent encoding -│ └── encode_latents.py # 3-stage pipeline encoder: shard prefetch + DMA + bfloat16 -├── data_pipeline/ -│ ├── 01_download_metadata.py # Download LAION-2B-en-aesthetic parquets -│ ├── 01b_download_diffusiondb.py # Download DiffusionDB (FT dataset) -│ ├── 01c_download_journeydb_images.py # Download JourneyDB subset (FT dataset) -│ ├── 02_filter_metadata.py # Quality filtering (aesthetic, CLIP, resolution) -│ ├── 03_build_hf_dataset.py # Hybrid HF Dataset build from parquet batches -│ ├── 03_download_images.py # img2dataset parallel image download -│ ├── 04_preprocess_to_cache.py # Tokenize captions, build hybrid cache -│ ├── 05_build_hf_dataset.py # Merge into HuggingFace Dataset (train/val split) -│ └── 06_filter_dataset.py # Remove celebrities/NSFW; hardlink filtered latents -├── configs/ -│ └── config.py # All hyperparameters in typed dataclasses -├── assets/ -│ ├── generate_plots.py # Architecture overview chart -│ └── architecture_overview.png -├── results/ -│ ├── training_status.md # Full training log: pre-training + fine-tuning (ep 1–42) -│ ├── loss_curve.csv # 3,493-point loss history across 232K steps -│ └── samples/ # Generated images from various checkpoints -│ ├── car.png, city.png, city_42.png, forest.png, forest_42.png -│ ├── fullbody.png, fullbody_2.png, fullbody_42.png -│ ├── landscape.png, landscape_42.png -│ ├── man.png, portrait.png, portrait_42.png -│ ├── portrait_centered.png, portrait_centered_2.png, portrait_centered_3.png -│ └── portrait_highcfg.png, cinematic.png, custom.png -├── .github/workflows/ -│ └── ci.yml # Lint + UNet forward pass smoke test -└── requirements.txt -``` +| Component | Implementation | Diffuser Reference | +|---|---|---| +| UNet | Full ~860M param SD 1.x UNet, custom `SpatialTransformer`, Flash SDP | `UNet2DConditionModel` | +| Scheduler (training) | DDPM with `scaled_linear` beta schedule | `DDPMScheduler` | +| Scheduler (inference) | DDIM with optional stochastic (eta) | `DDIMScheduler` | +| Text encoder | `CLIPTextModel` from transformers (frozen) | `CLIPTextModel` | +| VAE | `AutoencoderKL` from diffusers (frozen) | `AutoencoderKL` | +| Conditioning | Classifier-Free Guidance with `--negative` prompt | CFG | +| Multi-GPU | DDP (NCCL) with BF16 autocast | `accelerate` | +| Loss | Min-SNR-weighted MSE | — | --- -## Generated Samples - -Generated locally on **Apple Silicon (MPS)** using `src/inference.py` from the **epoch 17 checkpoint** (`sd_epoch_017.pt`, step 151,836, best loss 0.0947). DDIM 100 steps, CFG=7.5, seed=42. Load time ~95–106s, generation ~113–120s (1.1–1.2s/step on MPS). +## Citation -```bash -python3 src/inference.py \ - --checkpoint sd_epoch_017.pt \ - --prompt "a racing car cruising through forest roads" \ - --steps 100 --guidance 7.5 --seed 42 --output results/samples/custom.png - -python3 src/inference.py \ - --checkpoint sd_epoch_017.pt \ - --prompt "a man very happy lying down on the beach, cinematic and photorealistic" \ - --steps 100 --guidance 7.5 --seed 42 --output results/samples/cinematic.png - -python3 src/inference.py \ - --checkpoint sd_epoch_017.pt \ - --prompt "a man climbing mountain" \ - --steps 100 --guidance 7.5 --seed 42 --output results/samples/man.png +```bibtex +@software{bharati2026sdfromscratch, + author = {Atandra Bharati}, + title = {{SD-From-Scratch v1}: A Stable-Diffusion-class Latent Diffusion Model + Trained from Scratch on Dual {RTX} 5090s}, + year = {2026}, + url = {https://huggingface.co/atandra2000/sd-from-scratch-v1}, +} ``` -| Image | Prompt | Checkpoint | -|-------|--------|------------| -| ![custom](results/samples/custom.png) | "a racing car cruising through forest roads" | ep 17 | -| ![cinematic](results/samples/cinematic.png) | "a man very happy lying down on the beach, cinematic and photorealistic" | ep 17 | -| ![man](results/samples/man.png) | "a man climbing mountain" | ep 17 | -| ![city](results/samples/city.png) | city | ep 17 | -| ![forest](results/samples/forest.png) | forest | ep 17 | -| ![landscape](results/samples/landscape.png) | landscape | ep 17 | -| ![portrait](results/samples/portrait.png) | portrait | ep 17 | -| ![car](results/samples/car.png) | "a racing car cruising through forest roads" | ep 42 | -| ![portrait_highcfg](results/samples/portrait_highcfg.png) | portrait (high CFG) | ep 42 | -| ![portrait_centered](results/samples/portrait_centered.png) | portrait centered | ep 42 | -| ![fullbody](results/samples/fullbody.png) | full body portrait | ep 42 | -| ![city_42](results/samples/city_42.png) | city | ep 42 | -| ![forest_42](results/samples/forest_42.png) | forest | ep 42 | -| ![landscape_42](results/samples/landscape_42.png) | landscape | ep 42 | - ---- - -## Checkpoints - -Checkpoints are stored on Google Drive (~11.6 GB each): - -| Epoch | Phase | Global Step | Loss | Notes | Drive | -|-------|-------|-------------|------|-------|-------| -| 10 | Pre-training end | 136,279 | 0.1247 | | [Drive folder](https://drive.google.com/drive/folders/1EJdiLwaE6iMGksj9mr_CZkUF7RlXO9Wp) | -| 14 | LAION fine-tuning | 145,282 | 0.1257 | | [Drive folder](https://drive.google.com/drive/folders/1EJdiLwaE6iMGksj9mr_CZkUF7RlXO9Wp) | -| **17** | **LAION fine-tuning** | **151,818** | **0.0947** | **✅ Recommended for inference** | [Drive folder](https://drive.google.com/drive/folders/1EJdiLwaE6iMGksj9mr_CZkUF7RlXO9Wp) | -| 21 | DiffusionDB/JourneyDB | 181,177 | 0.1030 | ⚠️ Mode collapse — avoid for inference | [Drive folder](https://drive.google.com/drive/folders/1EJdiLwaE6iMGksj9mr_CZkUF7RlXO9Wp) | -| 42 | Extended fine-tuning | 232,235 | 0.1212 | Phase 2c — filtered dataset | [Drive folder](https://drive.google.com/drive/folders/1EJdiLwaE6iMGksj9mr_CZkUF7RlXO9Wp) | - -Each checkpoint contains: `unet_state_dict`, `ema_state_dict`, `optimizer_state_dict`, `scheduler_state_dict`, `epoch`, `global_step`, `best_loss`. - --- -## References +## License -- **LDM**: Rombach et al. (2022). *High-Resolution Image Synthesis with Latent Diffusion Models*. CVPR. -- **DDPM**: Ho et al. (2020). *Denoising Diffusion Probabilistic Models*. NeurIPS. -- **DDIM**: Song et al. (2020). *Denoising Diffusion Implicit Models*. ICLR. -- **CFG**: Ho & Salimans (2021). *Classifier-Free Diffusion Guidance*. -- **CLIP**: Radford et al. (2021). *Learning Transferable Visual Models from Natural Language Supervision*. ICML. -- **LAION**: Schuhmann et al. (2022). *LAION-5B: An Open Large-Scale Dataset for Training Next Generation Image-Text Models*. NeurIPS. -- **Min-SNR**: Hang et al. (2023). *Efficient Diffusion Training via Min-SNR Weighting Strategy*. ICCV. +MIT — see [LICENSE](LICENSE) for details. diff --git a/docs/architecture.md b/docs/architecture.md new file mode 100644 index 0000000..b1c1aad --- /dev/null +++ b/docs/architecture.md @@ -0,0 +1,122 @@ +# Model Architecture + +## Overview + +The model implements the full latent diffusion pipeline from Rombach et al. (2022), composed of four learned components and two schedulers. + +## Components + +### 1. VAE (Frozen) — `stabilityai/sd-vae-ft-mse` + +- **Encoder:** Down-convolves 512×512 RGB images to 64×64 latents (4 channels) +- **Decoder:** Up-convolves latents back to 512×512 RGB +- **Scale factor:** 0.18215 (multiply latents by this before UNet, divide after) +- **Parameters:** ~83M, frozen during training +- **Integration:** Loaded from diffusers `AutoencoderKL` via `PretrainedVAE` wrapper in `model.py` + +### 2. CLIP Text Encoder (Frozen) — `openai/clip-vit-large-patch14` + +- **Input:** 77 BPE tokens (truncated/padded) +- **Output:** `(B, 77, 768)` hidden states +- **Parameters:** ~123M, frozen during training +- **Integration:** Loaded from transformers `CLIPTextModel` via `PretrainedCLIPTextEncoder` wrapper + +### 3. UNet (Trainable) — ~860M Parameters + +The denoising backbone is a U-Net with spatial self-attention and cross-attention for text conditioning: + +``` + INPUT (4, 64, 64) + │ + ┌────────────┴────────────┐ + │ Conv2d 3×3 │ out=320 + └────────────┬────────────┘ + │ + ┌────────────┴────────────┐ + │ Block 0 (64×64) │ 320 ch, 2× ResNetBlock + │ No attention │ skip → decoder + └────────────┬────────────┘ + │ + ┌────────────┴────────────┐ + │ Downsample 2× │ + └────────────┬────────────┘ + ┌────────────┴────────────┐ + │ Block 1 (32×32) │ 640 ch, 2× ResNetBlock + │ ★ SpatialTransformer │ cross-attn on text (768) + └────────────┬────────────┘ + │ + ┌────────────┴────────────┐ + │ Downsample 2× │ + └────────────┬────────────┘ + ┌────────────┴────────────┐ + │ Block 2 (16×16) │ 1280 ch, 2× ResNetBlock + │ ★ SpatialTransformer │ cross-attn on text (768) + └────────────┬────────────┘ + │ + ┌────────────┴────────────┐ + │ Downsample 2× │ + └────────────┬────────────┘ + ┌────────────┴────────────┐ + │ Block 3 (8×8) │ 1280 ch, 2× ResNetBlock + │ ★ SpatialTransformer │ cross-attn on text (768) + └────────────┬────────────┘ + │ + ┌────────────┴────────────┐ + │ Bottleneck (8×8) │ 1280 ch, ResNet + Attn + └────────────┬────────────┘ + │ + ┌────────────────────────┼────────────────────────┐ + │ Mirror decoder with │ skip connections from │ + │ same block structure │ encoder │ + └────────────────────────┼────────────────────────┘ + │ + ┌────────────┴────────────┐ + │ Output Conv2d 3×3 │ out=4 (latent channels) + └────────────┬────────────┘ + │ + OUTPUT (4, 64, 64) +``` + +#### Key implementation details: + +- **SpatialTransformer:** GroupNorm → Conv 1×1 → multi-head self-attention → cross-attention (text as K,V) → Conv 1×1 → residual +- **ResNetBlock:** GroupNorm → SiLU → Conv 3×3 → GroupNorm → SiLU → dropout → Conv 3×3 + residual +- **Time conditioning:** Sinusoidal timestep embedding → MLP → added to each ResNetBlock (FiLM-style scale+shift) +- **Flash Attention:** Enabled via `torch.backends.cuda.enable_flash_sdp(True)` on Blackwell (cc ≥ 8.0) +- **Gradient checkpointing:** Saves ~40% activation memory by recomputing activations during backward (configurable via `--grad_ckpt`) + +### 4. DDPM Scheduler (Training) + +Defines the forward noising process: + +``` +z_t = √(ᾱ_t) · z_0 + √(1 − ᾱ_t) · ε where ε ~ N(0, I) +``` + +- **Beta schedule:** `scaled_linear` (β linearly spaced after square-root transformation) +- **Steps:** 1000 +- **Range:** β₁ = 0.00085, β₁₀₀₀ = 0.012 + +### 5. DDIM Scheduler (Inference) + +Deterministic reverse process (Song et al., 2020). Only 25–50 steps needed: + +``` +x̂_0 = (x_t − √(1−ᾱ_t) · ε_θ) / √(ᾱ_t) (clean latent estimate) +x_{t−1} = √(ᾱ_{t−1}) · x̂_0 + √(1 − ᾱ_{t−1}) · ε_θ (eta=0, deterministic) +``` + +**Note on `pred_x0.clamp(-1.0, 1.0)`:** The original 1000-step DDPM training code clamps the clean latent estimate to `[-1, 1]`. This is incorrect for inference — SD latents have a standard deviation of ~4, and clamping destroys signal quality. `model.py` now makes this clamp opt-in via `DDIMScheduler(clamp_pred_x0=False)` (default). The `inference.py` script and `SD_ImageGen.py` both set this to `False`. + +## Parameter Count Breakdown + +| Component | Parameters | Trainable | +|-----------|-----------|-----------| +| CLIP Text Encoder | 123M | ❌ | +| VAE (Encoder + Decoder) | 83M | ❌ | +| UNet | ~860M | ✅ | +| **Total pipeline** | **~1.07B** | **~860M** | + +## Memory Format + +The model uses `channels_last` memory format (NHWC) on CUDA for optimal convolution performance on Blackwell. For Apple Silicon (MPS) or AMD GPUs, pass `--memory_format contiguous` at the command line to disable this. diff --git a/docs/blog_post.md b/docs/blog_post.md new file mode 100644 index 0000000..4a035ca --- /dev/null +++ b/docs/blog_post.md @@ -0,0 +1,635 @@ +# I Trained Stable Diffusion From Scratch on 2× RTX 5090s — Here's What Actually Matters + +*A 48-epoch deep-dive into 860M parameters, 1.3M filtered images, and every footgun the textbooks don't warn you about.* + +--- + +> **TL;DR** +> +> Over a few months I trained a Latent Diffusion Model from scratch: +> +> - **860M-parameter UNet** (ch=320, ch_mults=(1,2,4,4)) — the full SD 1.x recipe +> - **2× RTX 5090** (Blackwell, 33.7 GB VRAM each), DDP + NCCL on RunPod +> - **48 epochs across 7 phases**, ~1.3M → 213k → 572k mixed dataset +> - **BF16 native, gradient checkpointing, Min-SNR γ, EMA decay 0.9999** +> - **Best loss: 0.0947 (epoch 16)**. Best images: epoch 42. +> +> The hardest problems had nothing to do with transformers, attention layers, or diffusion maths. They were: +> +> - corrupt JPEGs +> - VAE throughput +> - GPU DMA bottlenecks +> - catastrophic forgetting from sequential fine-tuning +> - broken latent distributions +> - and one line of code that silently destroyed image quality for days +> +> This article is everything I wish someone had told me before I started. + +*(Insert Image: Hero shot — side-by-side `val_epoch_002.png` (multicolored noise) vs `val_epoch_042.png` (photorealistic scenes).)* + +--- + +## 1. Why "From Scratch" Is Different + +Fine-tuning a pretrained SD is a weekend project. Training the UNet from random init is a different sport. You inherit every numerical instability, every dataset wart, and every CUDA quirk that the original authors solved silently before they shipped weights. + +If you're going to do this: + +- **Plan in phases**, not epochs. +- **Treat the loss curve with deep suspicion.** Visual coherence and MSE are only loosely correlated. +- **Budget 80 % of your time for data, 20 % for the model.** I wrote the UNet in two days. Building, filtering, and encoding the dataset took two weeks. + +--- + +## 2. The Architecture — Every Layer Matters + +Diffusion models don't operate on pixels; they operate in a compressed latent space. That single design choice is what makes an 860M-parameter model trainable on a consumer-grade pair of GPUs. + +### 2.1 The Big Picture + +Three frozen-or-trained components in a tight loop: + +1. **VAE** (frozen) — compresses 512×512×3 → 64×64×4 latents (8× spatial compression). +2. **CLIP text encoder** (frozen) — turns a prompt into a (77, 768) sequence. +3. **UNet** (trained) — given a noisy latent, a timestep, and the text embedding, predicts the noise. + +### 2.2 VAE: The Latent Translator + +I used `stabilityai/sd-vae-ft-mse`, kept in BF16 on Blackwell, and used `posterior.mean` rather than `posterior.sample()` for deterministic, cheaper encoding. + +The single most important constant in the entire codebase is the **VAE scale factor of 0.18215**: + +```python +latents = posterior.mean * 0.18215 # encode +decoded = vae.decode(latents / 0.18215) +``` + +This is the empirical standard deviation of the LAION-2B latent distribution. Skip it and your UNet trains on data with the wrong variance — the loss looks vaguely sensible but the model never converges. + +### 2.3 CLIP: The Language Bridge + +I used `openai/clip-vit-large-patch14`, also frozen. The UNet's cross-attention consumes the full `last_hidden_state` of shape **(B, 77, 768)** — not the pooled CLS vector. That sequence-level view is what allows cross-attention to "look at the word *red* when it draws the hat." + +### 2.4 The UNet — 860M Parameters of Denoising + +The exact SD 1.x topology: + +| Stage | Channels | Spatial | Self-attn | Cross-attn | +|---|---|---|---|---| +| 0 (in) | 320 | 64×64 | — | — | +| 1 | 640 | 32×32 | yes | yes | +| 2 | 1280 | 16×16 | yes | yes | +| 3 | 1280 | 8×8 | yes | yes | +| Bottleneck | 1280 | 8×8 | yes (1st only) | yes | +| Decoder | mirror | mirror | mirror | mirror | + +Concretely: + +```python +unet = UNetModel( + in_ch=4, out_ch=4, + ch=320, + res_blks=2, + ch_mults=(1, 2, 4, 4), + attn_lvls=(1, 2, 3), # no attention at full 64×64 + heads=8, + t_dim=320, ctx_dim=768, +) +``` + +Attention runs through PyTorch's `scaled_dot_product_attention` with **Flash SDP** and the **memory-efficient kernel** both enabled; the math fallback is disabled to make sure I don't silently fall off the fast path. + +### 2.5 Zero-Initialisation: The Calm Start + +Every residual block's final conv, every attention output projection, every MLP's last linear, and the UNet's `conv_out` are **zero-initialised**: + +```python +nn.init.zeros_(self.conv2.weight) +nn.init.zeros_(self.conv2.bias) +``` + +At step 0, the network predicts zero noise. The first gradient updates start the model in a region where activations are well-conditioned. Without this, large-channel residual blocks blow up in the first few hundred steps. + +### 2.6 Schedulers — DDPM for Training, DDIM for Inference + +I use the SD 1.x **scaled-linear** beta schedule: + +```python +betas = torch.linspace(0.00085**0.5, 0.012**0.5, 1000) ** 2 +``` + +This concentrates small betas near `t=0` (fine-detail noise) and grows them quickly toward `t=999`. + +- **Training:** DDPM, 1000 steps, ε-prediction objective. +- **Inference:** DDIM (η=0, deterministic). 25 steps is fine for scenes; **100 steps is required to keep facial detail crisp**. + +*(Insert Asset: Architecture diagram — VAE → UNet → CLIP wiring with shapes.)* + +> ### Engineering Takeaways +> - The **VAE scale factor (0.18215) is non-negotiable**. +> - **Zero-init** every output projection — your training stability comes from boring places. +> - Use `torch.nn.functional.scaled_dot_product_attention` and **explicitly enable Flash + mem-efficient SDP**. +> - Train DDPM, sample DDIM. + +--- + +## 3. The Data Pipeline — The Real Work + +If you take one lesson from this whole project, take this: **brutal filtering beats raw scale**. Every "ugly" image in your batch yanks the gradient sideways. + +### 3.1 Two-Stage Filtering + +I started from `laion/laion2B-en-aesthetic` (~2 B URL/caption rows) and ran two filtering passes. + +**Stage 1 — Broad pretraining (~1.3 M images kept)** + +| Filter | Threshold | +|---|---| +| Aesthetic score | ≥ 6.5 | +| CLIP similarity | ≥ 0.28 | +| Min resolution | 512×512 | +| Aspect ratio | 0.5 – 2.0 | +| Watermark prob. | < 0.15 (script default) | +| NSFW | "UNLIKELY" only | +| Caption length | 20 – 300 chars | +| Dedup | URL-level | + +The script (`02_filter_metadata.py`) reads, normalises column names across LAION's inconsistent schemas, then writes filtered parquet files. Survival rate from LAION-2B-en: roughly **0.065 %**. + +**Stage 2 — Rigorous refinement (213,458 images)** + +After Phase 1, I cranked the thresholds: + +- Aesthetic ≥ 7.5 +- CLIP similarity ≥ 0.30 +- Watermark probability < 0.15 (kept tight) + +It feels obscene to throw away 99.9 % of your data. Do it anyway. This subset produced the single biggest visual jump of the entire project. + +### 3.2 Other Datasets + +The full data mix: + +| Dataset | Images | Purpose | Phases | +|---|---|---|---| +| LAION-2B-en aesthetic | 1.3 M / 213k filtered | Broad pretraining + refinement | 1, 2, 6, 7 | +| DiffusionDB | ~205k | Synthetic prompt diversity | 3, 7 | +| JourneyDB | ~277k | Midjourney-style aesthetics | 3, 7 | +| VGGFace2 | 51k | Face anatomy | 4, 6, 7 | +| COCO (detection-datasets) | 59k | Full-body / scene integration | 5, 6, 7 | + +DiffusionDB and JourneyDB came as zipped image dumps — I rebuilt them into standard `*.tar` WebDataset shards with 512×512 centre-cropped JPEGs and matching `*.txt` captions (`01b_*.py` and `01c_*.py`) so they could feed the same downstream pipeline. + +### 3.3 The VAE Latent Encoding Pipeline (the real performance work) + +If you encode latents on the fly during training, your GPUs spend most of their time waiting for the VAE. The fix is to **pre-compute every latent once and stream them as `.npy` files** during training. + +`encode_latents.py` is a 4-stage concurrent pipeline: + +1. **Shard prefetch thread** — copies tar shards from the network mount (RunPod MFS) to local NVMe. A bounded queue (`SHARD_PREFETCH=3`) provides backpressure so `/tmp` never overflows. +2. **Extractor + decode submit thread** — opens each local tar, indexes JPEG members by stem, and fans decode tasks into a worker pool. +3. **Decode pool (16 threads, OpenCV)** — `cv2.imdecode` is ~3× faster than PIL. Each worker decodes, resizes preserving aspect ratio, centre-crops to 512×512, and normalises to `[-1, 1]`. +4. **GPU main thread** — packs results into a **pre-allocated pinned-memory staging buffer**, kicks off an async H2D copy on a **dedicated CUDA stream**, then runs `vae.encode()` under BF16 autocast. + +```python +staging = torch.empty(BATCH_SIZE, 3, 512, 512, dtype=torch.float32, pin_memory=True) +h2d_stream = torch.cuda.Stream(device=device) + +with torch.cuda.stream(h2d_stream): + pixel = staging[:n].to(device, non_blocking=True) +torch.cuda.current_stream(device).wait_stream(h2d_stream) + +with torch.autocast("cuda", dtype=torch.bfloat16): + latents = vae.encode(pixel) # (n, 4, 64, 64), fp16-saved +``` + +Page-locked memory means the H2D copy goes through **DMA**, bypassing the kernel's software bounce buffer. The dedicated stream overlaps that copy with the next batch's decode futures resolving on the CPU. + +I also `torch.compile`'d the VAE in `reduce-overhead` mode, which gave another 20–30 % throughput after the first batch's graph capture. End result: **~1.3 M images → ~48 GB of fp16 latents** (32 KB per file), saved as one `.npy` per image. Resume is `O(1)` via `os.listdir`. + +*(Insert Asset: Pipeline diagram — Shard Prefetch → Tar Extract → Parallel Decode → Pinned DMA → VAE → fp16 .npy.)* + +> ### Engineering Takeaways +> - **Quality dominates quantity.** Filter ruthlessly. +> - **Pre-compute latents.** Don't pay the VAE cost in your training hot loop. +> - **Pinned memory + dedicated CUDA stream + DMA** is how you actually saturate PCIe 5.0. +> - `torch.compile` is safe **outside** the training loop (encoding, evaluation). Inside the loop it's a trap (see §7). + +--- + +## 4. The Training Loop — Engineering for Scale + +Training 860M parameters is mostly plumbing. Get the plumbing right and the model trains. Get it wrong and you'll lose a week to non-deterministic crashes. + +### 4.1 Hardware & Distribution + +Two RTX 5090s (sm_120, 33.7 GB each), connected by PCIe 5.0, talking via NCCL. I use **DistributedDataParallel** rather than DeepSpeed/FSDP — for a model that fits in a single card's VRAM, DDP's all-reduce is the simplest correct choice. + +Key DDP knobs that mattered: + +```python +ddp_unet = DDP( + model.unet, + device_ids=[rank], + output_device=rank, + find_unused_parameters=False, # every param sees a gradient + gradient_as_bucket_view=True, # zero-copy bucket views +) +``` + +And on gradient accumulation steps, **skip the all-reduce**: + +```python +with ddp_unet.no_sync() if (step + 1) % grad_accum != 0 else nullcontext(): + loss.backward() +``` + +That single line saves ~50 % of inter-GPU traffic. + +### 4.2 VRAM Survival Guide + +Effective batch = **24 / GPU × 2 GPUs × 2 grad-accum = 96**. To fit it: + +1. **BF16 native autocast.** Blackwell does BF16 natively; no `GradScaler` needed because BF16's dynamic range matches FP32. + ```python + with torch.autocast("cuda", dtype=torch.bfloat16): + noise_pred = ddp_unet(noisy_latents, t, text_emb) + ``` +2. **Gradient checkpointing on every UNet residual block** (`use_reentrant=False`). Trades ~30 % compute for ~40 % VRAM. Without it, ch=320 won't fit at this batch size. +3. **TF32 + Flash + mem-efficient SDP enabled, math SDP disabled.** Saves both VRAM and compute. +4. **`PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`** to reduce fragmentation across long runs. +5. **Channels-last memory format** for the UNet. (Important caveat below — it briefly stopped working on sm_120 during Phase 4.) + +I observed peak reserved VRAM around **25.3 GB** during the heaviest epochs — comfortable on a 33.7 GB card. + +### 4.3 The "Zero I/O" Training Loop + +Because every latent is a 32 KB `.npy`, the entire 213k LAION cache (~7 GB) fits in RAM trivially. At startup, a 16-thread loader streams every file into a `dict[str, torch.Tensor]`: + +```python +LATENT_FRACTION = 1.0 +with ThreadPoolExecutor(max_workers=16) as pool: + for fut in as_completed(...): + _LATENT_CACHE[name] = torch.from_numpy(np.load(f).copy()) +``` + +The custom `LatentDistributedSampler` then yields **only indices whose latent is actually in cache**, shards them evenly across ranks per epoch, and drops the remainder so every GPU sees an equal-sized slice. Result: **zero disk I/O inside the epoch**, GPUs at >95 % utilisation. + +### 4.4 Loss Weighting — Min-SNR γ + +Vanilla MSE treats every timestep equally. At high noise (large `t`), the signal in `ε` is weak and the gradients are noisy. Min-SNR (Hang et al., 2023) downweights those steps: + +```python +def _min_snr_weight(t, sched, gamma): + acp = sched.alphas_cumprod[t] + snr = acp / (1.0 - acp).clamp_min(1e-6) + return (snr.clamp(max=gamma) / snr).float() +``` + +I trained with **γ=5.0** for most phases. Dropping to γ=2.0 at epoch 16 produced my best loss ever (0.0947) — and silently broke face geometry. More on that in §7. γ=3.0 turned out to be the practical sweet spot. + +### 4.5 Classifier-Free Guidance Dropout + +For CFG to work at inference, the model has to know how to denoise *without* a prompt. During training I randomly replace the text embedding with a precomputed unconditional embedding: + +```python +empty_tok = tokenizer([""], padding="max_length", max_length=77, return_tensors="pt").to(device) +uncond_text_emb = model.encode_text(empty_tok.input_ids, empty_tok.attention_mask).squeeze(0).detach() +# ... reuse on every step +``` + +CFG dropout ramps from **0.05 (broad pretraining) → 0.15 (fine-tuning)**. + +### 4.6 EMA — The Secret Weapon + +The Exponential Moving Average of UNet weights is the single largest "free" quality win in the entire project. I keep it **GPU-resident** (no CPU round-trip per step) with decay 0.9999 and a warmup schedule: + +```python +d = min(decay, (1 + step) / (10 + step)) +self.shadow[n].lerp_(p.detach(), 1.0 - d) +``` + +The warmup formula prevents early noisy updates from dominating the shadow. At validation time, I swap live weights for the EMA shadow, sample, then restore — `ema.apply_shadow()` / `ema.restore()`. + +The visual gap between live and EMA weights is enormous: + +- **Live weights:** noisy, colour-leaky, jittery composition. +- **EMA weights:** stable, photorealistic, coherent. + +### 4.7 The Optimiser & Scheduler + +```python +optimizer = AdamW( + model.unet.parameters(), + lr=args.lr, + weight_decay=1e-2, + betas=(0.9, 0.999), + eps=1e-8, + fused=True, # single CUDA kernel; falls back if unavailable +) +``` + +LR schedule is **`SequentialLR(LinearLR warmup → CosineAnnealingLR)`** with `eta_min = lr * 1e-2`. Warmup is `min(args.warmup_steps, total_steps // 10)` so short fine-tunes don't get a 500-step warmup out of a 1500-step run. + +Gradients are clipped with `clip_grad_norm_(..., max_norm=1.0)` immediately before each optimiser step. + +### 4.8 Fault Tolerance + +I lost more time to "the pod died at 80 % of an epoch" than to anything else. Two patches: + +- **Atomic checkpoint writes.** Save to `*.tmp`, then `os.replace`. A killed process never leaves a half-written `.pt`. +- **`--save_steps`.** A mid-epoch step checkpoint (e.g. every 1500 global steps) lets a crash cost minutes instead of an hour. + +*(Insert Asset: Training step flow — Latent batch → BF16 forward → Min-SNR-weighted MSE → DDP all-reduce → AdamW step → EMA update.)* + +> ### Engineering Takeaways +> - **BF16 + Flash SDP + gradient checkpointing** is the right Blackwell triple. +> - **`no_sync()` on accumulation steps** halves your DDP traffic. +> - **Pre-encode and RAM-cache** to get a zero-I/O hot loop. +> - **EMA is not optional.** Decay 0.9999, GPU-resident, with warmup. +> - **Atomic saves + step checkpoints** are the cheapest insurance you'll ever buy. + +--- + +## 5. The Multi-Phase Journey — From Noise to Coherence + +Training was seven distinct phases, each with its own dataset, learning rate, and tactical goal. Total **48 epochs**, best loss **0.0947** (epoch 16), best images **epoch 42**. + +### Phase 1 — LAION Broad Pretraining (Epochs 1–10) + +- **Data:** 1,315,411 LAION images, aesthetic ≥ 6.5. +- **LR:** 1e-5 peak, 500-step warmup. +- **Epoch time:** ~3 hours. +- **Loss:** 0.220 → 0.1247. + +By **epoch 3 the outputs were literally multicolored static** — ghost-shapes if you squinted. I almost killed the run. By epoch 6 vague blobs started cohering; by epoch 10 the model could roughly distinguish "sunset" from "person." + +> **Footnote on the "crashes" at epochs 4 and 9:** these were **not** OOMs. They were deliberate `KeyboardInterrupt`s on my end (config tweaks). Peak reserved VRAM that whole phase was ~25.3 GB — well under the 33.7 GB headroom. The earlier draft of this post called them OOM crashes; that was wrong. + +### Phase 2 — LAION Rigorous Refinement (Epochs 11–17) + +- **Data:** 213,458 filtered LAION images, aesthetic ≥ 7.5, CLIP sim ≥ 0.30. +- **LR:** 1e-5 (fresh restart), 500-step warmup. +- **Epoch time:** ~30 minutes. +- **Loss:** 0.1260 → **0.0947 (epoch 16, best ever)** → 0.1083 (epoch 17). + +This phase produced the **single biggest visual quality jump** of the project. The lesson is uncomfortable: throwing away 84 % of your already-filtered data made the model dramatically better. + +At epoch 16 I tried Min-SNR γ=2.0 — got a beautiful loss number, watched faces start melting, reverted to γ=3.0 at epoch 17. (See §7.4.) + +Epochs 15 and 17 were stopped early via deliberate `KeyboardInterrupt`. The **epoch 15 EMA checkpoint is still my recommended inference base for pure image quality** before the synthetic-data domain shift. + +### Phase 3 — DiffusionDB + JourneyDB (Epochs 18–22) + +- **Data:** ~482k synthetic/curated images (DiffusionDB 500 shards + JourneyDB 10 archives → ~705k latents with mirroring). +- **LR:** 1e-5 (restart). +- **Epoch time:** ~1 hour. +- **Loss:** 0.0947 → 0.1207 → 0.1191. + +The loss **jumped** when I introduced the new mix. That's expected — the domain shifted from photographic LAION to a heavier synthetic distribution and the model had to re-calibrate. By epoch 22 it had stabilised. + +### Phase 4 — VGGFace2 Face Fine-Tuning (Epochs 23–29+) + +- **Data:** 51,786 VGGFace2 images @ 512×512 with templated captions ("photorealistic portrait of a person, soft studio lighting"). +- **LR:** 2e-6, 200-step warmup. +- **CFG dropout:** 0.05 → 0.15. +- **Epoch time:** ~2.3 hours. + +Face anatomy improved dramatically — bilateral eye symmetry, correct nose/mouth ratios, plausible skin texture. + +This phase needed a runtime fix: **`channels_last` briefly broke on sm_120** in my PyTorch 2.6+cu124 build. I switched to `contiguous_format` for the duration of Phase 4, then **reverted to `channels_last` once the environment stabilised** — that's what the current `SD_Train.py` ships with. (The earlier blog draft implied the switch was permanent; it wasn't.) + +### Phase 5 — COCO Full-Body Fine-Tuning (Epochs 30–38) + +- **Data:** COCO `detection-datasets`, filtered to person bbox ≥ 55 % image height → 59,494 images. +- **LR:** 1.5e-6. + +Background integration and body proportions improved. But **faces regressed** — the model overwrote some of Phase 4's gains with COCO's utilitarian framing. **This was my first hard lesson in catastrophic forgetting.** + +### Phase 6 — Mixed Consolidation (Epochs 39–42) + +- **Data:** LAION 150k + VGGFace2 50k + COCO 58k mixed 60/20/20. +- **LR:** 1e-6. + +Scene quality snapped back to near-perfect, face/body gains were preserved. At epoch 42 the model hit its visual "sweet spot." That checkpoint became **`sd_epoch_042.pt`** — the file I now use as my reference base. + +### Phase 7 — Final Comprehensive Consolidation (Epochs 43–48) + +- **Data:** LAION 213k + DiffusionDB/JourneyDB 250k + VGGFace2 51k + COCO 58k ≈ 572k total (37 / 44 / 9 / 10 %). +- **LR:** 1e-6. + +Epochs 43–44 came in at losses 0.1202 / 0.1193 — flat, healthy, no signs of divergence. Training stayed in progress while I battled pod-level CUDA restart issues on RunPod (cu124 vs cu13.2 image mismatches). + +*(Insert Asset: Validation time-lapse grid — epoch 1 → 10 → 17 → 22 → 42 → 48.)* + +### Loss History at a Glance + +| Epoch | Loss | Phase | Note | +|---|---|---|---| +| 1 | ~0.220 | P1 | Start | +| 2 | 0.1583 | P1 | First big drop | +| 7–8 | 0.1507 / 0.1508 | P1 | Plateau | +| 9–10 | 0.1249 / 0.1247 | P1 | End of broad | +| 11–14 | 0.126 → 0.1248 | P2 | Filtered LAION | +| 15 | 0.1026 | P2 | | +| **16** | **0.0947** | P2 | **Best ever (γ=2.0)** | +| 17 | 0.1083 | P2 | γ back to 3.0 | +| 18 | 0.1207 | P3 | Domain shift | +| 22 | 0.1191 | P3 | Stabilised | +| 42 | ~0.115 | P6 | Released: `sd_epoch_042.pt` | +| 43–44 | 0.1202 / 0.1193 | P7 | Final consolidation | + +> ### Engineering Takeaways +> - **Stop using the loss curve to decide when to stop.** Use a fixed-seed visual grid. +> - **The best loss number happened at epoch 16; the best images happened at epoch 42.** +> - **Sequential fine-tuning is a trap.** Mix datasets in every batch. +> - **Patience is a hyperparameter.** Diffusion models look catastrophic for the first 5–8 epochs. Trust the process. + +--- + +## 6. Inference — From Checkpoint to Canvas + +A `.pt` is not a picture. Several pieces have to slot together. + +### 6.1 DDIM Sampling + +Training uses 1000 DDPM steps; inference uses **DDIM** (deterministic, η=0). My rules of thumb: + +| Use case | Steps | +|---|---| +| Quick exploration | 25 | +| Production scenes | 50 | +| Faces, fine detail | **100** | + +Below 50 steps, faces look smudged. Above 100, diminishing returns. + +### 6.2 Classifier-Free Guidance + +Run the UNet twice each step — once with the prompt embedding, once with the unconditional/empty embedding — then combine: + +```python +guided = noise_uncond + guidance_scale * (noise_cond - noise_uncond) +latents = scheduler.step(guided, t, latents) +``` + +| Subject | CFG scale | +|---|---| +| Scenes / landscapes | 7.5 | +| Portraits | 8.5 | +| Anything | **stop at 9.0** | + +Higher than 9.0 introduces oversaturation and CFG artefacts (waxy skin, blown highlights). + +### 6.3 Negative Prompts + +Negative prompts replace the unconditional embedding with something like `"blurry, low quality, distorted, deformed"`. They're the fastest way to prune common generative failure modes. + +**Implementation note for the repo:** I have two inference scripts. + +- `SD_ImageGen.py` (CUDA) — properly wires negative prompts through `generate(..., negative_prompts=...)`. +- `inference.py` (CUDA/MPS, Apple Silicon friendly) — currently parses `--negative` but doesn't thread it into `generate()`; the unconditional branch still uses the empty string. **If you want true negative-prompt CFG, use `SD_ImageGen.py` until I fix that.** + +### 6.4 Live vs EMA — The A/B Test That Mattered Most + +Same prompt, same seed, same steps, same CFG, same UNet weights but different snapshot: + +| Snapshot | Result | +|---|---| +| Live UNet weights | Noisy, jittery composition, "colour leakage" between subject and background | +| EMA shadow (decay 0.9999) | Stable, photorealistic, coherent | + +**EMA at inference is non-negotiable.** + +> **Honest caveat about the saved `val_epoch_*.png` grids.** My current `SD_Train.py:validate()` does call `ema.apply_shadow()` before sampling — so going forward, validation grids reflect EMA quality. But the grids saved earlier in the run (everything you see in `sd-val-imgs/`) were produced before that fix was wired in and reflect *live* weights. Mentally add ~15 % of perceived quality when looking at them. + +*(Insert Asset: Side-by-side — live-weight portrait vs EMA portrait, identical seed.)* + +--- + +## 7. Pitfalls — The Hard-Won Lessons + +The most valuable part of any project is the failures. + +### 7.1 The Latent Clamp Bug — Two Days Lost to One Line + +Training looked great. Inference output was greyish, washed-out, incoherent. I spent two days convinced the UNet was broken before I found this line in `DDIMScheduler.step()`: + +```python +pred_x0 = pred_x0.clamp(-1.0, 1.0) # ← THIS +``` + +It looks reasonable. It is catastrophically wrong. **SD latents are not in `[-1, 1]`** — their standard deviation is ≈ 4.0. Clamping decapitates the signal. + +I removed it at inference and the images snapped into focus. + +```python +# in inference.py: monkey-patch the scheduler before sampling +SD_Model.DDIMScheduler.step = _fixed_ddim_step # the no-clamp version +``` + +**Caveat:** the clamp is still present in `SD_Model.py:736` because removing it changes behaviour for any old script that imports the class directly. The inference scripts patch it out at runtime. The proper long-term fix is to delete that line. + +**Lesson:** Never assume latent distributions match pixel distributions. + +### 7.2 `torch.compile` × Gradient Checkpointing — Two Great Tastes That Don't Mix + +I tried wrapping the UNet in `torch.compile` to squeeze more speed. It worked — until I enabled gradient checkpointing. Immediate `AssertionError` deep inside Dynamo. + +The reason: `checkpoint(use_reentrant=False)` re-runs the forward pass inside a dynamo-disabled context during backward. The compiled wrapper sees a forward whose Dynamo state has been pulled out from under it. + +I stuck with eager mode. With BF16 + Flash SDP the throughput penalty was negligible. `torch.compile` is still great outside the loop — I used it on the VAE during latent encoding (§3.3). + +### 7.3 `channels_last` on sm_120 + +When I first enabled `torch.channels_last` for the UNet on the 5090s, my PyTorch 2.6+cu124 build threw shape-mismatch errors inside cuDNN. The fix that worked for Phase 4 was switching to `contiguous_format`. After a CUDA / driver update I switched back to `channels_last` and it now runs cleanly — the current `SD_Train.py` ships with `channels_last`. + +**Lesson:** Optimisations that "everyone uses" can break on bleeding-edge hardware. Have a fallback ready. + +### 7.4 Min-SNR γ = 2.0 — The Beautiful Wrong Number + +Chasing better gradients, I dropped Min-SNR γ from 5.0 to 2.0 in Phase 2. The loss looked gorgeous — **0.0947, my best ever** — but the *images* started melting. Faces lost geometry because high-noise timesteps were being over-weighted relative to the low-noise timesteps that actually carry facial detail. + +I reverted to γ=3.0 and the metric got worse while the images got better. + +**Lesson:** lower MSE ≠ better images. Visual validation is the only ground truth. + +### 7.5 The Patience Problem + +In the first three epochs I almost gave up. The loss looked fine but the validation grids were noise. I kept the fixed-seed grids anyway. By epoch 5 vague shapes appeared. By epoch 8 they became things. + +**Lesson:** the "aha" moment in diffusion happens late. Don't kill a run because the first few epochs look like television static. + +*(Insert Asset: Side-by-side — "broken" clamp-bug output vs "fixed" output after removing the clamp.)* + +--- + +## 8. What I'd Do Differently + +If I started over tomorrow: + +1. **Mixed batching from day one.** I wasted weeks in a sequential mindset. Catastrophic forgetting between Phases 4 and 5 cost me a week. Always sample across datasets in every batch. +2. **Start with a lower Min-SNR γ (2.0–2.5) carefully.** For aesthetic-heavy data, lower γ can speed convergence — but test it visually, not just by loss. +3. **Test inference every single epoch.** I'd have caught the latent clamp bug in week one instead of week six. +4. **Validate with EMA from the very first epoch.** Don't ship a fix you forgot to backfill into your visualisations. +5. **Delete the training-time `pred_x0.clamp`** instead of monkey-patching around it. Make the wrong version unreachable. +6. **Wire `--negative` through `inference.py:generate()` properly** so both inference scripts behave identically. + +--- + +## 9. What's Next + +`sd_epoch_042.pt` is a solid base. The roadmap from here: + +1. **Rectified flow fine-tuning.** Current frontier for diffusion image quality and the natural successor to ε-prediction. +2. **LCM (Latent Consistency Model) distillation.** Enables 1–4 step generation — real-time inference. +3. **ControlNet.** Spatial conditioning — pose, depth, edge maps — without retraining the base. +4. **SD_Model_v2.** Already designed: MM-DiT backbone, dual CLIP-L + OpenCLIP-bigG text encoders, native rectified flow. This is where I want to live next. + +--- + +## 10. The Repo at a Glance + +``` +StableDiffusion/ +├── SD_Model.py # UNet + VAE/CLIP wrappers + DDPM/DDIM schedulers +├── SD_Train.py # 2× RTX 5090 DDP + BF16 training loop +├── SD_ImageGen.py # CUDA inference (full negative-prompt CFG) +├── inference.py # CUDA/MPS inference (DDIM clamp monkey-patched) +├── encode_latents.py # 4-stage VAE → fp16 .npy pipeline +├── 01_download_metadata.py # LAION parquet snapshot +├── 01b_download_diffusiondb.py # DiffusionDB → 512×512 tar shards +├── 01c_download_journeydb_images.py # JourneyDB → 512×512 tar shards +├── 02_filter_metadata.py # aesthetic / CLIP / watermark / NSFW / dedup +├── 03_download_images.py # img2dataset LAION downloader +├── 03_build_hf_dataset.py # DiffusionDB/JourneyDB → Arrow HF dataset +├── 04_preprocess_to_cache.py # Tars → parquet (image_key + CLIP tokens) +├── 05_build_hf_dataset.py # Parquet → Arrow HF dataset +├── sd_epoch_042.pt # Released checkpoint (~12.5 GB) +├── sd-val-imgs/ # val_epoch_001..043.png (live-weight grids) +├── sd-logs/ # captured training.log / output*.log +└── generated_images/ # curated epoch-42 renders +``` + +--- + +## Final Thoughts + +Training a Stable Diffusion model from scratch was one of the most rewarding engineering projects I've worked on — and the "AI" was the easy part. The real difficulty is in: + +- treating data quality as a hyperparameter, +- moving bytes between disk, host RAM, pinned buffers and VRAM without ever blocking the GPU, +- knowing which optimisations compose and which ones explode when you stack them, +- and trusting visual validation over a loss number that lies to you. + +If you're considering this: do it. But go in with your eyes open. The transformers will be fine. It's the JPEG decoder, the pinned buffer, the EMA decay and the one clamp at line 736 that will decide whether your model converges. + +--- + +## Resources + +- **Core frameworks:** [PyTorch](https://pytorch.org), [Hugging Face Diffusers](https://github.com/huggingface/diffusers), [Transformers](https://github.com/huggingface/transformers) +- **Models & data:** [LAION](https://laion.ai), [DiffusionDB](https://github.com/poloclub/diffusiondb), [JourneyDB](https://huggingface.co/datasets/JourneyDB/JourneyDB), [VGGFace2](https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/), [COCO](https://cocodataset.org/) +- **Performance:** [Flash Attention](https://github.com/Dao-AILab/flash-attention), [NCCL](https://developer.nvidia.com/nccl) +- **Reference papers:** Ho et al. 2020 (DDPM), Song et al. 2020 (DDIM), Hang et al. 2023 (Min-SNR), Rombach et al. 2022 (Latent Diffusion) + +**Tags:** #MachineLearning #DeepLearning #StableDiffusion #PyTorch #GenerativeAI #GPUComputing #AITraining #RTX5090 #Blackwell diff --git a/docs/data-pipeline.md b/docs/data-pipeline.md new file mode 100644 index 0000000..49918e7 --- /dev/null +++ b/docs/data-pipeline.md @@ -0,0 +1,134 @@ +# Data Pipeline + +The data pipeline processes LAION-2B-en-aesthetic (~12 million images filtered to aesthetic score ≥ 6.5) into a format suitable for training. Each script is a self-contained step. + +## Pipeline Steps + +### Step 1: Download Metadata + +```bash +python data_pipeline/01_download_metadata.py +``` + +Downloads LAION-2B-en-aesthetic parquet shards from the Hugging Face Hub. The dataset contains ~230 million image-URL–caption pairs with precomputed aesthetic scores, CLIP similarity scores, and watermark probabilities. + +**Output:** `laion_metadata/*.parquet` (multiple shards) + +### Step 1b-c: Alternative Sources (Optional) + +```bash +python data_pipeline/01b_download_diffusiondb.py +python data_pipeline/01c_download_journeydb_images.py +``` + +Alternative or supplementary datasets — not used in the final training run. + +### Step 2: Filter Metadata + +```bash +python data_pipeline/02_filter_metadata.py +``` + +Applies quality filters to the raw metadata: + +| Filter | Threshold | Rationale | +|---|---|---| +| Aesthetic score | ≥ 6.5 | Top ~5% of LAION-2B, high visual quality | +| CLIP similarity | ≥ 0.28 | Text–image alignment | +| Resolution | ≥ 512 px | Minimum for 512×512 training | +| Watermark probability | ≤ 0.5 | Reduce watermarked images | +| NSFW | False | Safety filter | + +**Output:** `laion_filtered/*.parquet` (~12M entries, ~9 GB) + +### Step 3: Download Images + +```bash +python data_pipeline/03_download_images.py +``` + +Downloads images from the filtered URLs using `img2dataset`. Handles: +- HTTP timeouts and retries +- Image format validation (JPEG, PNG) +- Resolution checks (skip images < 512 px in either dimension) +- WebDataset tar packaging for efficient I/O + +**Output:** `laion_shards/` — WebDataset tar files (~6 TB raw images) + +### Step 4: Preprocess to Cache + +```bash +python data_pipeline/04_preprocess_to_cache.py +``` + +Tokenizes captions with CLIP tokenizer and applies basic augmentations: + +- BPE tokenization (77 tokens, truncation/padding) +- CLIP attention mask computation +- Image resizing to 512×512 (center crop if necessary) + +**Output:** `laion_cache/` — Tokenized cache shards + +### Step 5: Build Hugging Face Dataset + +```bash +python data_pipeline/05_build_hf_dataset.py +``` + +Assembles tokenized data and raw image paths into a Hugging Face `datasets.Dataset` (memory-mapped Arrow format). Splits into train/validation (5K val samples). + +**Output:** `laion_hf_dataset/` (HF Dataset on disk, ~50 GB) + +### Step 5b: Fine-tuning Subset + +```bash +python data_pipeline/05_build_hf_dataset.py --output laion_hf_dataset_ft +``` + +Builds a finer-quality subset for fine-tuning — stricter aesthetic thresholds (≥ 7.0) and additional deduplication. + +### Step 6: Filter Dataset (Optional) + +```bash +python data_pipeline/06_filter_dataset.py +``` + +Additional filtering after dataset construction (e.g., CLIP score re-ranking, near-dedup). + +## VAE Pre-encoding + +Latent pre-encoding is the last preprocessing step before training. It converts all images in the dataset to VAE latents, which are saved as `.npy` files. During training, the DataLoader reads pre-encoded latents directly — no VAE forward pass needed. + +```bash +# Single GPU encoding +python src/encode_latents.py + +# Dual-GPU parallel encoding +python src/encode_pipeline.py +``` + +**Output:** `laion_latents/*.npy` (41.4 GB total, 4×64×64 per image) + +### Why Pre-encode? + +1. **VRAM savings:** The VAE forward pass requires ~2 GB per image at 512×512. Pre-encoding frees this memory for the UNet. +2. **Speed:** Latent loading from `.npy` is ~10× faster than VAE encoding per image. +3. **Dataset augmentation:** Pre-encoding is deterministic — all training epochs see identical latents from the same images, which is actually desired for reproducibility. + +## Production Training Pipeline + +For training, the DataLoader reads from two sources simultaneously: + +1. **Latents:** Memory-mapped from `laion_latents/*.npy` via `numpy.load` + `torch.from_numpy` +2. **Tokenized text:** From the HF Dataset `input_ids` and `attention_mask` columns + +```python +class LatentDataset(Dataset): + def __getitem__(self, idx): + latent = np.load(self.latent_files[idx]) # (4, 64, 64) float32 + ids = self.dataset[idx]["input_ids"] # (77,) int64 + mask = self.dataset[idx]["attention_mask"] # (77,) int64 + return {"pixel_values": latent, "input_ids": ids, "attention_mask": mask} +``` + +The DataLoader uses `pin_memory=True`, `prefetch_factor=4`, and `num_workers=16` (per GPU) to saturate the 2× RTX 5090 setup. diff --git a/docs/images/architecture/architecture_overview.png b/docs/images/architecture/architecture_overview.png new file mode 120000 index 0000000..08e9efc --- /dev/null +++ b/docs/images/architecture/architecture_overview.png @@ -0,0 +1 @@ +../../../assets/architecture_overview.png \ No newline at end of file diff --git a/docs/images/hero/collage.png b/docs/images/hero/collage.png new file mode 100644 index 0000000..e22773a Binary files /dev/null and b/docs/images/hero/collage.png differ diff --git a/docs/images/samples/car.png b/docs/images/samples/car.png new file mode 120000 index 0000000..2c7e804 --- /dev/null +++ b/docs/images/samples/car.png @@ -0,0 +1 @@ +../../../results/samples/car.png \ No newline at end of file diff --git a/docs/images/samples/cinematic.png b/docs/images/samples/cinematic.png new file mode 120000 index 0000000..61cb654 --- /dev/null +++ b/docs/images/samples/cinematic.png @@ -0,0 +1 @@ +../../../results/samples/cinematic.png \ No newline at end of file diff --git a/docs/images/samples/city.png b/docs/images/samples/city.png new file mode 120000 index 0000000..9fd86b2 --- /dev/null +++ b/docs/images/samples/city.png @@ -0,0 +1 @@ +../../../results/samples/city.png \ No newline at end of file diff --git a/docs/images/samples/city_42.png b/docs/images/samples/city_42.png new file mode 120000 index 0000000..8ac661e --- /dev/null +++ b/docs/images/samples/city_42.png @@ -0,0 +1 @@ +../../../results/samples/city_42.png \ No newline at end of file diff --git a/docs/images/samples/custom.png b/docs/images/samples/custom.png new file mode 120000 index 0000000..f005091 --- /dev/null +++ b/docs/images/samples/custom.png @@ -0,0 +1 @@ +../../../results/samples/custom.png \ No newline at end of file diff --git a/docs/images/samples/forest.png b/docs/images/samples/forest.png new file mode 120000 index 0000000..4fa5bff --- /dev/null +++ b/docs/images/samples/forest.png @@ -0,0 +1 @@ +../../../results/samples/forest.png \ No newline at end of file diff --git a/docs/images/samples/forest_42.png b/docs/images/samples/forest_42.png new file mode 120000 index 0000000..f918c9c --- /dev/null +++ b/docs/images/samples/forest_42.png @@ -0,0 +1 @@ +../../../results/samples/forest_42.png \ No newline at end of file diff --git a/docs/images/samples/fullbody.png b/docs/images/samples/fullbody.png new file mode 120000 index 0000000..3db9fa6 --- /dev/null +++ b/docs/images/samples/fullbody.png @@ -0,0 +1 @@ +../../../results/samples/fullbody.png \ No newline at end of file diff --git a/docs/images/samples/fullbody_2.png b/docs/images/samples/fullbody_2.png new file mode 120000 index 0000000..b5bb262 --- /dev/null +++ b/docs/images/samples/fullbody_2.png @@ -0,0 +1 @@ +../../../results/samples/fullbody_2.png \ No newline at end of file diff --git a/docs/images/samples/fullbody_42.png b/docs/images/samples/fullbody_42.png new file mode 120000 index 0000000..9fb2a65 --- /dev/null +++ b/docs/images/samples/fullbody_42.png @@ -0,0 +1 @@ +../../../results/samples/fullbody_42.png \ No newline at end of file diff --git a/docs/images/samples/landscape.png b/docs/images/samples/landscape.png new file mode 120000 index 0000000..ae90d81 --- /dev/null +++ b/docs/images/samples/landscape.png @@ -0,0 +1 @@ +../../../results/samples/landscape.png \ No newline at end of file diff --git a/docs/images/samples/landscape_42.png b/docs/images/samples/landscape_42.png new file mode 120000 index 0000000..5910a3a --- /dev/null +++ b/docs/images/samples/landscape_42.png @@ -0,0 +1 @@ +../../../results/samples/landscape_42.png \ No newline at end of file diff --git a/docs/images/samples/man.png b/docs/images/samples/man.png new file mode 120000 index 0000000..d0c1694 --- /dev/null +++ b/docs/images/samples/man.png @@ -0,0 +1 @@ +../../../results/samples/man.png \ No newline at end of file diff --git a/docs/images/samples/portrait.png b/docs/images/samples/portrait.png new file mode 120000 index 0000000..89224e6 --- /dev/null +++ b/docs/images/samples/portrait.png @@ -0,0 +1 @@ +../../../results/samples/portrait.png \ No newline at end of file diff --git a/docs/images/samples/portrait_42.png b/docs/images/samples/portrait_42.png new file mode 120000 index 0000000..29b5f46 --- /dev/null +++ b/docs/images/samples/portrait_42.png @@ -0,0 +1 @@ +../../../results/samples/portrait_42.png \ No newline at end of file diff --git a/docs/images/samples/portrait_centered.png b/docs/images/samples/portrait_centered.png new file mode 120000 index 0000000..0157b19 --- /dev/null +++ b/docs/images/samples/portrait_centered.png @@ -0,0 +1 @@ +../../../results/samples/portrait_centered.png \ No newline at end of file diff --git a/docs/images/samples/portrait_centered_2.png b/docs/images/samples/portrait_centered_2.png new file mode 120000 index 0000000..ef97896 --- /dev/null +++ b/docs/images/samples/portrait_centered_2.png @@ -0,0 +1 @@ +../../../results/samples/portrait_centered_2.png \ No newline at end of file diff --git a/docs/images/samples/portrait_centered_3.png b/docs/images/samples/portrait_centered_3.png new file mode 120000 index 0000000..341e927 --- /dev/null +++ b/docs/images/samples/portrait_centered_3.png @@ -0,0 +1 @@ +../../../results/samples/portrait_centered_3.png \ No newline at end of file diff --git a/docs/images/samples/portrait_highcfg.png b/docs/images/samples/portrait_highcfg.png new file mode 120000 index 0000000..df93e09 --- /dev/null +++ b/docs/images/samples/portrait_highcfg.png @@ -0,0 +1 @@ +../../../results/samples/portrait_highcfg.png \ No newline at end of file diff --git a/docs/inference.md b/docs/inference.md new file mode 100644 index 0000000..c84577c --- /dev/null +++ b/docs/inference.md @@ -0,0 +1,135 @@ +# Inference Guide + +## Quick Start + +```bash +# Download checkpoint first +python scripts/download_checkpoint.py + +# Basic inference +python src/inference.py \ + --prompt "a beautiful sunset over mountain peaks" \ + --checkpoint checkpoints/sd_epoch_042.pt + +# Apple Silicon (MPS) — auto-detected +python src/inference.py \ + --prompt "a cat wearing a spacesuit" \ + --checkpoint checkpoints/sd_epoch_042.pt +``` + +## Usage Options + +### Basic Parameters + +| Flag | Default | Description | +|---|---|---| +| `--prompt` | (required) | Text prompt | +| `--batch` | — | `.txt` file with one prompt per line (overrides `--prompt`) | +| `--negative` | `""` | Negative prompt for CFG | +| `--checkpoint` | `sd_epoch_042.pt` | Path to checkpoint | +| `--steps` | 50 | DDIM steps: 25 (fast), 50 (good), 100 (best) | +| `--guidance` | 7.5 | CFG scale: 1.0 (no guidance), 7.5 (balanced), 15+ (high) | +| `--seed` | 42 | Random seed for reproducibility | +| `--width` / `--height` | 512 | Output dimensions (must be multiples of 8) | +| `--batch_size` | 1 | Images per generation call (for batch prompts) | +| `--output` | `output.png` | Output filename (single prompt) | +| `--output_dir` | `./outputs` | Output directory (batch mode) | + +### Examples + +```bash +# High quality, deterministic +python src/inference.py \ + --prompt "a cinematic shot of a mountain lake at sunrise, professional photography" \ + --steps 100 --guidance 7.5 --seed 42 + +# With negative prompt +python src/inference.py \ + --prompt "a portrait of a woman" \ + --negative "blurry, low quality, deformed hands, extra fingers" \ + --steps 50 --guidance 9.0 + +# Batch mode — generate from 100 prompts +python src/inference.py \ + --batch prompts.txt \ + --output_dir ./generated \ + --steps 50 --guidance 7.5 + +# Stochastic sampling (DDPM-like, more variety) +python src/inference.py \ + --prompt "fantasy landscape" \ + --eta 1.0 --seed 999 +``` + +## Scripts + +### `src/inference.py` — Apple Silicon + CUDA + +Designed for MacBook (MPS) and single-GPU inference. Automatically detects MPS, CUDA, or CPU. + +Features: +- Automatic device selection +- Negative prompt support +- Batch processing from text file +- DDIM with optional stochastic (eta) +- EMA weight loading from checkpoint + +### `src/SD_ImageGen.py` — Alternative CLI + +Full-featured CLI with additional options: +- Supports both raw UNet weights and EMA shadow weights +- Negative prompt per-sample broadcasting +- autocast BF16 on CUDA +- Image grid generation for multiple outputs + +### `src/generate.py` — Programmatic API + +```python +from generate import generate_images + +images = generate_images( + prompts=["a cosmic nebula with vibrant colors"], + checkpoint_path="checkpoints/sd_epoch_042.pt", + num_steps=50, + guidance_scale=7.5, + seed=42, + device="cuda", # or "mps", "cpu" +) +images[0].save("nebula.png") +``` + +## DDIM Parameters + +### Steps + +| Steps | Quality | Speed | +|---|---|---| +| 25 | Good | 2× faster | +| 50 | Recommended | Baseline | +| 100 | Excellent | 2× slower | +| 200+ | Diminishing | Not worth it | + +### Eta (Stochasticity) + +| Eta | Behavior | +|---|---| +| 0.0 | Deterministic — same seed always produces the same image | +| 0.5 | Moderate stochasticity — small variations | +| 1.0 | DDPM-like — maximum variety, but may lose fidelity | + +### CFG Scale + +| Scale | Effect | +|---|---| +| 1.0 | No guidance — pure model prior, often blurry/unrelated | +| 5.0–7.5 | Balanced — recommended range | +| 9.0–12.0 | Strong guidance — more prompt alignment, may oversaturate | +| 15.0+ | Excessive — often produces artifacts, burned-in look | + +## Output Quality Tips + +1. **Use descriptive prompts:** "a cinematic shot of..." works better than "a photo of..." +2. **Negative prompts help:** Common negatives: "blurry, low quality, deformed, extra limbs, bad anatomy, ugly, text, watermark" +3. **Seed selection:** For a given prompt, try seeds 0–20 and pick the best +4. **Steps vs. quality:** 50 steps is usually sufficient; 100+ gives marginal gains +5. **CFG tuning:** Start at 7.5, adjust ±2 based on output character diff --git a/docs/training-loop.md b/docs/training-loop.md new file mode 100644 index 0000000..a4d4ecc --- /dev/null +++ b/docs/training-loop.md @@ -0,0 +1,113 @@ +# Training Loop + +## Optimization Stack + +The training loop in `src/train.py` is optimized for dual RTX 5090 (Blackwell) GPUs: + +1. **DDP (DistributedDataParallel):** True multi-GPU with NCCL backend. Each GPU processes its own micro-batch; gradients are all-reduced across both GPUs. +2. **BF16 native (torch.autocast):** Blackwell has native BF16 support — no GradScaler needed. +3. **Flash Attention (FlashSDP):** Enabled on Blackwell (cc ≥ 8.0), ~2–4× faster attention with O(N) memory vs O(N²). +4. **Gradient checkpointing:** Saves ~40% activation memory by recomputing activations on backward. +5. **Fused AdamW:** `torch.optim.AdamW` with `fused=True` — single kernel for the entire optimizer step. +6. **`channels_last` memory format:** Optimal for convolutions on Blackwell (configurable via `--memory_format`). +7. **EMA on GPU:** 32 GB VRAM per GPU is plenty — no CPU round-trips for the EMA shadow copy. +8. **Min-SNR loss weighting:** Better gradient signal across timesteps (Hang et al., 2023). + +## Training Procedure + +### Epoch Structure + +``` +for each epoch: + sampler.set_epoch(epoch) # shuffle DDP sampler + for each batch: + latents = batch["pixel_values"] # pre-encoded VAE latents from disk + ids, mask = batch["input_ids"] # tokenized text + t = sample_timesteps(batch) # uniform [0, 1000) + noise = randn_like(latents) + z_t = noise_scheduler.add_noise(latents, noise, t) + + # CFG dropout: randomly replace conditioning with empty text + if random() < cfg_dropout: + ids, mask = empty_tokens() + + with autocast("cuda", bf16): + noise_pred = ddp_unet(z_t, t, context) + loss = min_snr_weighted_mse(noise_pred, noise, t) + + loss.backward() + optimizer.step() + lr_scheduler.step() + ema.update(unet_raw) +``` + +### Loss Function + +Standard MSE loss between predicted and actual noise, weighted by Min-SNR (Hang et al., 2023): + +```python +def min_snr_weighted_mse(noise_pred, noise, t, gamma=5.0): + snr = alphas_cumprod[t] / (1 - alphas_cumprod[t]) + weights = torch.clamp(snr, max=gamma) + loss = F.mse_loss(noise_pred, noise, reduction="none") + return (loss * weights.view(-1, 1, 1, 1)).mean() +``` + +### Learning Rate Schedule + +1. **Linear warmup** for 500 steps from 0 → peak LR +2. **Cosine annealing** from peak LR → 0 over remaining steps +3. **Peak LR:** 1e-4 for pre-training, 1e-5 for fine-tuning + +### Hyperparameters + +| Parameter | Pre-training (ep 1–10) | Fine-tuning (ep 11–42) | +|---|---|---| +| Batch size per GPU | 24 | 24 | +| Gradient accumulation | 2 | 2 | +| Effective batch | 96 | 96 | +| Peak LR | 1e-4 (ep 1), 1e-5 (ep 2–10) | 1e-5 | +| Min-SNR γ | 5.0 | 2.5 | +| CFG dropout | 0.05 | 0.05 | +| Weight decay | 0.01 | 0.01 | +| EMA decay | 0.9999 | 0.9999 | +| Warmup steps | 500 | 500 | + +## EMA (Exponential Moving Average) + +Polyak-style EMA maintains a shadow copy of all UNet parameters: + +```python +shadow[n].lerp_(p.detach(), 1.0 - decay) +``` + +The shadow weights are updated after every optimizer step with an effective decay that increases to 0.9999 over the first 10 steps: + +```python +d = min(decay, (1 + step) / (10 + step)) +``` + +At evaluation and checkpoint time, EMA weights are swapped in for inference (producing noticeably better samples than the live weights). + +## Checkpoint Format + +Each checkpoint saved to `checkpoints/sd_epoch_NNN.pt` contains: + +```python +{ + "unet_state_dict": ..., # Raw UNet weights (for strict loading) + "ema_state_dict": { # EMA shadow copy + "shadow": {...}, # Parameter name → tensor mapping (with prefix stripping) + "step_count": 232235, + "decay": 0.9999, + }, + "optimizer_state_dict": ..., # Full AdamW state (for resume) + "lr_scheduler_state_dict": ..., + "epoch": 42, + "global_step": 232235, + "best_loss": 0.0947, + "config": {...}, # Training configuration snapshot +} +``` + +The checkpoint is ~12.5 GB (BF16 weights are stored as FP32 for CPU loading stability). diff --git a/scripts/download_checkpoint.py b/scripts/download_checkpoint.py new file mode 100644 index 0000000..2c04294 --- /dev/null +++ b/scripts/download_checkpoint.py @@ -0,0 +1,58 @@ +""" +Download the released checkpoint from Hugging Face Hub. + +Usage: + python scripts/download_checkpoint.py + python scripts/download_checkpoint.py --output checkpoints/sd_epoch_042.pt +""" + +import argparse +import sys +from pathlib import Path + +try: + from huggingface_hub import hf_hub_download, login +except ImportError: + print("Error: huggingface_hub not installed. Run: pip install huggingface_hub") + sys.exit(1) + +REPO_ID = "atandra2000/sd-from-scratch-v1" +FILENAME = "sd_epoch_042.pt" + +def parse_args(): + p = argparse.ArgumentParser(description="Download SD-From-Scratch checkpoint") + p.add_argument("--output", type=str, default="checkpoints/sd_epoch_042.pt", + help="Output path for the checkpoint") + p.add_argument("--token", type=str, default=None, + help="HF Hub token (optional for public repos)") + return p.parse_args() + +def main(): + args = parse_args() + + output = Path(args.output) + output.parent.mkdir(parents=True, exist_ok=True) + + if output.exists(): + size_gb = output.stat().st_size / (1024**3) + print(f"Already exists: {output} ({size_gb:.1f} GB)") + return + + print(f"Downloading {FILENAME} from {REPO_ID}...") + print(f"Size: ~12.5 GB — this will take a while.") + + if args.token: + login(token=args.token) + + path = hf_hub_download( + repo_id=REPO_ID, + filename=FILENAME, + local_dir=output.parent, + local_dir_use_symlinks=False, + ) + + size_gb = Path(path).stat().st_size / (1024**3) + print(f"Downloaded: {path} ({size_gb:.1f} GB)") + +if __name__ == "__main__": + main() diff --git a/src/SD_ImageGen.py b/src/SD_ImageGen.py index 6f9647c..c2038e9 100644 --- a/src/SD_ImageGen.py +++ b/src/SD_ImageGen.py @@ -82,7 +82,7 @@ def load_model(checkpoint_path: str, device: torch.device) -> StableDiffusionMod ema = EMA(model.unet, decay=0.9999) ema.load_state_dict(ckpt["ema_state_dict"]) # Apply EMA shadow weights to the model - ema.apply_shadow() + ema.apply_shadow(model.unet) logger.info(f"✅ EMA weights applied (step {ema.step_count}) — better image quality") elif "model_state_dict" in ckpt: diff --git a/src/SD_Model.py b/src/SD_Model.py index 5ee4312..c6debeb 100644 --- a/src/SD_Model.py +++ b/src/SD_Model.py @@ -660,12 +660,14 @@ class DDIMScheduler: def __init__( self, - steps: int = 1000, - beta_start: float = 0.00085, - beta_end: float = 0.012, - schedule: str = "scaled_linear", + steps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.012, + schedule: str = "scaled_linear", + clamp_pred_x0: bool = False, ): self.num_train_timesteps = steps + self.clamp_pred_x0 = clamp_pred_x0 if schedule == "scaled_linear": betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, steps) ** 2 @@ -678,10 +680,6 @@ def __init__( self.timesteps: Optional[torch.Tensor] = None self.num_inference_steps: Optional[int] = None - def to(self, device: torch.device): - self.alphas_cumprod = self.alphas_cumprod.to(device) - return self - def set_timesteps(self, num_steps: int, device: torch.device): """ Compute the subset of evenly-spaced timesteps used for inference. @@ -733,7 +731,8 @@ def step( # Step 1: Estimate clean latent x̂_0 from noisy x_t pred_x0 = (x_t - (1.0 - alpha_t).sqrt() * noise_pred) / alpha_t.sqrt() - pred_x0 = pred_x0.clamp(-1.0, 1.0) # prevent extreme values propagating + if self.clamp_pred_x0: + pred_x0 = pred_x0.clamp(-1.0, 1.0) # Step 2: Direction from x̂_0 towards x_t (diffusion "velocity") dir_xt = (1.0 - alpha_prev).sqrt() * noise_pred diff --git a/src/SD_Train.py b/src/SD_Train.py index 69d19c4..4fe5f38 100644 --- a/src/SD_Train.py +++ b/src/SD_Train.py @@ -324,6 +324,7 @@ def train_epoch( uncond_text_emb=None, use_wandb=False, use_min_snr=True, min_snr_gamma=5.0, cfg_dropout=0.0, save_steps=0, ckpt_dir="checkpoints", best_loss=float("inf"), + memory_format: str = "channels_last", ) -> tuple[float, int]: ddp_unet.train() model.text_encoder.eval() @@ -338,7 +339,8 @@ def train_epoch( continue try: latents = batch["pixel_values"].to(device, dtype=torch.bfloat16, non_blocking=True) - latents = latents.contiguous(memory_format=torch.channels_last) + if memory_format == "channels_last": + latents = latents.contiguous(memory_format=torch.channels_last) ids = batch["input_ids"].to(device, non_blocking=True) mask = batch["attention_mask"].to(device, non_blocking=True) @@ -603,7 +605,8 @@ def main(rank, world_size, args): noise_scheduler = DDPMScheduler(steps=1000, beta_start=0.00085, beta_end=0.012, schedule="scaled_linear") model = StableDiffusionModel(vae, text_enc, unet, noise_scheduler).to(device) - model.unet = model.unet.to(memory_format=torch.channels_last) + if args.memory_format == "channels_last": + model.unet = model.unet.to(memory_format=torch.channels_last) noise_scheduler.to(device) # ── Dataset + unconditional embedding (precomputed once) ────────────────── @@ -681,6 +684,7 @@ def main(rank, world_size, args): use_wandb=args.use_wandb, use_min_snr=args.min_snr, min_snr_gamma=args.min_snr_gamma, cfg_dropout=args.cfg_dropout, save_steps=args.save_steps, ckpt_dir=args.ckpt_dir, best_loss=best_loss, + memory_format=args.memory_format, ) if avg_loss < best_loss: @@ -720,7 +724,7 @@ def main(rank, world_size, args): # Data parser.add_argument("--cache_path", type=str, default="laion_hf_dataset/train") - parser.add_argument("--latent_dir", type=str, default="laion_latents/laion_latents") + parser.add_argument("--latent_dir", type=str, default="laion_latents") parser.add_argument("--val_size", type=int, default=500) # Model @@ -743,6 +747,10 @@ def main(rank, world_size, args): parser.add_argument("--no-min-snr", dest="min_snr", action="store_false") parser.add_argument("--min_snr_gamma", type=float, default=5.0) parser.add_argument("--cfg_dropout", type=float, default=0.05, help="CFG dropout probability.") + parser.add_argument("--memory_format", type=str, default="channels_last", + choices=("channels_last", "contiguous"), + help="Memory format for UNet and latents. channels_last speeds up convs on NVIDIA GPUs. " + "Use 'contiguous' for AMD or Apple Silicon.") # Checkpointing parser.add_argument("--save_every", type=int, default=1, help="Save checkpoint every N epochs.") diff --git a/src/SD_Train_v2.py b/src/SD_Train_v2.py index ab82006..6cd8236 100644 --- a/src/SD_Train_v2.py +++ b/src/SD_Train_v2.py @@ -1485,7 +1485,7 @@ def main(rank: int, world_size: int, args: argparse.Namespace): # ── Data ────────────────────────────────────────────────────────────────── parser.add_argument("--cache_path", type=str, default="laion_hf_dataset/train", help="Path to HuggingFace dataset (Arrow format from 05_build_hf_dataset.py).") - parser.add_argument("--latent_dir", type=str, default="laion_latents/laion_latents", + parser.add_argument("--latent_dir", type=str, default="laion_latents", help="Directory of pre-cached .npy latent files (v1 latents are reusable).") parser.add_argument("--val_size", type=int, default=500) parser.add_argument("--latent_fraction", type=float, default=1.0, diff --git a/src/inference.py b/src/inference.py index f0b0f7d..2c28b99 100644 --- a/src/inference.py +++ b/src/inference.py @@ -160,24 +160,30 @@ def load_model(checkpoint_path: str, device: torch.device): @torch.no_grad() def generate( - prompts: list, + prompts: list, vae, text_encoder, unet, tokenizer, scheduler, - device: torch.device, - num_steps: int = 50, - guidance_scale: float = 7.5, - seed: int = 42, - height: int = 512, - width: int = 512, + device: torch.device, + num_steps: int = 50, + guidance_scale: float = 7.5, + seed: int = 42, + height: int = 512, + width: int = 512, + negative_prompts: list = None, ) -> list: assert height % 8 == 0 and width % 8 == 0 batch_size = len(prompts) latent_h = height // 8 latent_w = width // 8 + if negative_prompts is None: + negative_prompts = [""] * batch_size + elif len(negative_prompts) == 1 and batch_size > 1: + negative_prompts = negative_prompts * batch_size + # Encode text def encode(texts): tok = tokenizer( @@ -188,7 +194,7 @@ def encode(texts): return emb.float() cond_emb = encode(prompts) - uncond_emb = encode([""] * batch_size) + uncond_emb = encode(negative_prompts) ctx = torch.cat([uncond_emb, cond_emb], dim=0) # Initial noise — generate on CPU then move (MPS generator workaround) @@ -279,23 +285,22 @@ def main(): all_images = [] for i in range(0, len(prompts), args.batch_size): batch = prompts[i:i + args.batch_size] - # Apply negative prompt by replacing uncond with negative embedding - # (handled inside generate via the empty string default) print(f"Generating: {batch[0][:80]}") t0 = time.time() imgs = generate( - prompts = batch, - vae = vae, - text_encoder = text_encoder, - unet = unet, - tokenizer = tokenizer, - scheduler = scheduler, - device = device, - num_steps = args.steps, - guidance_scale = args.guidance, - seed = args.seed + i, - height = args.height, - width = args.width, + prompts = batch, + vae = vae, + text_encoder = text_encoder, + unet = unet, + tokenizer = tokenizer, + scheduler = scheduler, + device = device, + num_steps = args.steps, + guidance_scale = args.guidance, + seed = args.seed + i, + height = args.height, + width = args.width, + negative_prompts = [args.negative] if args.negative else None, ) elapsed = time.time() - t0 print(f" Done in {elapsed:.1f}s ({elapsed/args.steps:.2f}s/step)") diff --git a/src/model.py b/src/model.py index 3e7084a..c34d543 100644 --- a/src/model.py +++ b/src/model.py @@ -682,12 +682,14 @@ class DDIMScheduler: def __init__( self, - steps: int = 1000, - beta_start: float = 0.00085, - beta_end: float = 0.012, - schedule: str = "scaled_linear", + steps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.012, + schedule: str = "scaled_linear", + clamp_pred_x0: bool = False, ): self.num_train_timesteps = steps + self.clamp_pred_x0 = clamp_pred_x0 if schedule == "scaled_linear": betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, steps) ** 2 @@ -751,7 +753,8 @@ def step( # Step 1: Estimate clean latent x̂_0 from noisy x_t pred_x0 = (x_t - (1.0 - alpha_t).sqrt() * noise_pred) / alpha_t.sqrt() - pred_x0 = pred_x0.clamp(-1.0, 1.0) # prevent extreme values propagating + if self.clamp_pred_x0: + pred_x0 = pred_x0.clamp(-1.0, 1.0) # prevent extreme values propagating # Step 2: Direction from x̂_0 towards x_t (diffusion "velocity") dir_xt = (1.0 - alpha_prev).sqrt() * noise_pred diff --git a/src/train.py b/src/train.py index de68237..9b82f74 100644 --- a/src/train.py +++ b/src/train.py @@ -324,6 +324,7 @@ def train_epoch( uncond_text_emb=None, use_wandb=False, use_min_snr=True, min_snr_gamma=5.0, cfg_dropout=0.0, save_steps=0, ckpt_dir="checkpoints", best_loss=float("inf"), + memory_format: str = "channels_last", ) -> tuple[float, int]: ddp_unet.train() model.text_encoder.eval() @@ -337,7 +338,8 @@ def train_epoch( continue try: latents = batch["pixel_values"].to(device, dtype=torch.bfloat16, non_blocking=True) - latents = latents.contiguous(memory_format=torch.channels_last) + if memory_format == "channels_last": + latents = latents.contiguous(memory_format=torch.channels_last) ids = batch["input_ids"].to(device, non_blocking=True) mask = batch["attention_mask"].to(device, non_blocking=True) @@ -611,7 +613,8 @@ def main(rank, world_size, args): noise_scheduler = DDPMScheduler(steps=1000, beta_start=0.00085, beta_end=0.012, schedule="scaled_linear") model = StableDiffusionModel(vae, text_enc, unet, noise_scheduler).to(device) - model.unet = model.unet.to(memory_format=torch.channels_last) + if args.memory_format == "channels_last": + model.unet = model.unet.to(memory_format=torch.channels_last) noise_scheduler.to(device) # ── Dataset + unconditional embedding (precomputed once) ────────────────── @@ -689,6 +692,7 @@ def main(rank, world_size, args): use_wandb=args.use_wandb, use_min_snr=args.min_snr, min_snr_gamma=args.min_snr_gamma, cfg_dropout=args.cfg_dropout, save_steps=args.save_steps, ckpt_dir=args.ckpt_dir, best_loss=best_loss, + memory_format=args.memory_format, ) if avg_loss < best_loss: @@ -728,7 +732,7 @@ def main(rank, world_size, args): # Data parser.add_argument("--cache_path", type=str, default="laion_hf_dataset/train") - parser.add_argument("--latent_dir", type=str, default="laion_latents/laion_latents") + parser.add_argument("--latent_dir", type=str, default="laion_latents") parser.add_argument("--val_size", type=int, default=500) # Model @@ -751,6 +755,10 @@ def main(rank, world_size, args): parser.add_argument("--no-min-snr", dest="min_snr", action="store_false") parser.add_argument("--min_snr_gamma", type=float, default=5.0) parser.add_argument("--cfg_dropout", type=float, default=0.05, help="CFG dropout probability.") + parser.add_argument("--memory_format", type=str, default="channels_last", + choices=("channels_last", "contiguous"), + help="Memory format for UNet and latents. channels_last speeds up convs on NVIDIA GPUs. " + "Use 'contiguous' for AMD or Apple Silicon.") # Checkpointing parser.add_argument("--save_every", type=int, default=1) diff --git a/summary.md b/summary.md new file mode 100644 index 0000000..ca2bb9a --- /dev/null +++ b/summary.md @@ -0,0 +1,294 @@ +# Stable Diffusion 1.x — From-Scratch Training Summary + +**Project:** Custom Stable Diffusion model trained entirely from scratch +**Hardware:** 2× RTX 5090 (33.7 GB VRAM each), DDP + NCCL +**Platform:** RunPod +**Total training epochs:** 48 (in progress) +**Best loss achieved:** 0.0947 (epoch 16) + +--- + +## Model Architecture + +- **UNet:** 860M parameters, ch=320, 8 attention heads, 4-stage encoder/decoder (ch_mults=(1,2,4,4) → 320/640/1280/1280), attn_lvls=(1,2,3), 2 res_blks/stage, t_dim=320, ctx_dim=768 +- **Init:** Zero-initialized output projections (ResNet conv2, attention proj, cross-attn `to_out`, MLP final, SpatialTransformer `proj_out`, UNet `conv_out`) +- **Attention:** Self + cross-attention via PyTorch SDPA (Flash + mem-efficient, math kernel disabled) +- **VAE:** Frozen `stabilityai/sd-vae-ft-mse` (BF16 during training; uses `posterior.mean`, not sampled) +- **Text encoder:** Frozen `openai/clip-vit-large-patch14` (`last_hidden_state` → 77×768) +- **Scheduler:** DDPM scaled_linear betas (0.00085 → 0.012, 1000 steps) for training; DDIM for inference +- **Precision:** BF16 autocast, no GradScaler (Blackwell has native BF16, no FP16 underflow) +- **Memory format:** Currently `channels_last` in training (`SD_Train.py:341,606`); during Phase 4 (VGGFace2) `contiguous_format` was required on sm_120 — reverted back to `channels_last` once the build stabilized + +--- + +## Training Stack + +| Component | Detail | +|-----------|--------| +| Distributed | DDP + NCCL, 2 GPUs, `gradient_as_bucket_view=True`, `find_unused_parameters=False` | +| Optimizer | Fused AdamW (betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2; falls back to non-fused if unavailable) | +| LR schedule | `SequentialLR`: LinearLR warmup (start_factor=1e-2 → 1.0) + CosineAnnealingLR (eta_min = lr·1e-2) | +| Loss weighting | Min-SNR (default γ=5.0) | +| CFG training | Dropout (0.05 → 0.15 across phases) — uses precomputed unconditional embedding | +| Gradient checkpointing | Enabled on every `UNetResBlock` (~40% VRAM saving) via `torch.utils.checkpoint` (use_reentrant=False) | +| Gradient clipping | `clip_grad_norm_` max_norm=1.0 | +| `no_sync()` on accum | DDP all-reduce skipped on non-optimizer steps | +| EMA | GPU-resident shadow, decay=0.9999 with warmup `d = min(decay, (1+step)/(10+step))` | +| Batch size | 24/GPU × 2 GPUs × 2 grad_accum = 96 effective | +| Data | Latents loaded fully into RAM (`LATENT_FRACTION=1.0`, 16-thread `.npy` loader) | +| TF32 / SDP | TF32 enabled; Flash + mem-efficient SDP enabled, math SDP disabled | +| CUDA alloc | `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` | +| torch.compile | **Disabled** — conflicts with gradient checkpointing's dynamo-disabled forward | +| Monitoring | WandB (`stable-diffusion` project) | +| Fault tolerance | Atomic `.tmp` → `os.replace` saves; optional `--save_steps` mid-epoch checkpoints (e.g. every ~1500 steps) | + +--- + +## Training Phases + +### Phase 1 — LAION 1.3M Broad (Epochs 1–10) +- **Dataset:** LAION-2B-en aesthetic ≥6.5, 1,315,411 images +- **Filters:** aesthetic ≥6.5, CLIP sim ≥0.28, 512×512+, no watermark/NSFW +- **LR:** 1e-5 peak, 500 warmup steps +- **Epoch time:** ~3 hrs +- **Loss:** 0.22 → ~0.149 (epoch 7: 0.1507, epoch 8: 0.1508, epoch 9: 0.1249, epoch 10: 0.1247) +- **Interrupts:** Epochs 4, 9 stopped early via deliberate `KeyboardInterrupt` (manual termination — *not* VRAM/OOM failures). Peak VRAM observed ~25.3 GB on a 33.7 GB card. +- **Key learning:** Coarse structure, color statistics, spatial frequency + +### Phase 2 — LAION 400k High-Quality (Epochs 11–17) +- **Dataset:** LAION filtered aesthetic ≥7.5, watermark <0.15 (script default; ≤0.25 in blog), CLIP sim ≥0.30 → 213,458 images +- **LR:** 1e-5 peak (fresh restart), 500 warmup +- **Epoch time:** ~30 min +- **Loss:** 0.1260 (ep 11) → 0.1254 → 0.1254 → 0.1248 → 0.1026 (ep 15) → **0.0947 (epoch 16, best ever)** → 0.1083 (ep 17) +- **Interrupts:** Epochs 15 and 17 were stopped early via deliberate `KeyboardInterrupt` (manual termination — *not* OOM). +- **Key learning:** This was the single biggest quality jump in the entire project. High-aesthetic filtering > raw scale. +- **Note:** Epoch 16 used Min-SNR gamma=2.0 (too low, face geometry issues). Epoch 15 EMA weights recommended for inference. + +### Phase 3 — DiffusionDB + JourneyDB Mixed (Epochs 18–22) +- **Dataset:** ~482k images from DiffusionDB (500 shards) + JourneyDB (10 archives) → ~705k latents +- **LR:** 1e-5 (restart) +- **Epoch time:** ~1 hr +- **Loss:** ~0.120 → 0.1191 +- **Note:** Loss jumped from 0.0947 → 0.12 on domain shift (synthetic → real). Expected behavior. Epoch 22 used for inference. + +### Phase 4 — VGGFace2 Face Fine-Tuning (Epochs 23–29+) +- **Dataset:** VGGFace2, 51,786 images @ 512×512, template captions +- **LR:** 2e-6 (surgical fine-tune) +- **Warmup:** 200 steps +- **CFG dropout:** 0.15 (↑ from 0.05) +- **Epoch time:** ~2.3 hrs (999k samples) +- **Key fix needed:** `channels_last` → `contiguous_format` for sm_120 Blackwell +- **Result:** Face anatomy dramatically improved — bilateral eye symmetry, correct nose/mouth, skin texture + +### Phase 5 — COCO Full-Body Fine-Tuning (Epochs 30–38) +- **Dataset:** COCO detection-datasets, person bbox ≥55% height filter → 59,494 images +- **LR:** 1.5e-6 +- **Warmup:** 150 steps +- **Epoch time:** ~17 min +- **Result:** Background integration improved, body proportions corrected. Face remained slightly elongated. + +### Phase 6 — Mixed Consolidation (Epochs 39–42) +- **Dataset:** LAION 150k + VGGFace2 50k + COCO 58k = 250k mixed (60/20/20 ratio) +- **LR:** 1e-6 +- **Warmup:** 100 steps +- **Epoch time:** ~17 min +- **Loss:** ~0.12 at epoch 42 +- **Result:** Scene quality restored (forest/landscape/city near-perfect). Face and body gains preserved. + +### Phase 7 — Final Comprehensive Consolidation (Epochs 43–48, in progress) +- **Dataset:** LAION 213k + DM 250k + VGGFace2 51k + COCO 58k = 572k (37/44/9/10%) +- **LR:** 1e-6 +- **Warmup:** 150 steps +- **Epoch time:** ~40 min +- **Status:** Epochs 43–44 completed (loss 0.1202/0.1193). Resumed from epoch 044. Pod CUDA issues causing restart difficulties. + +--- + +## Loss History + +| Epoch | Loss | Phase | Notes | +|-------|------|-------|-------| +| 1 | ~0.220 | P1 | Start | +| 2 | 0.1583 | P1 | Large initial drop | +| 7 | 0.1507 | P1 | | +| 8 | 0.1508 | P1 | | +| 9 | 0.1249 | P1 | Resumed after crash | +| 10 | 0.1247 | P1 | End of broad phase | +| 11 | 0.1260 | P2 | Dataset switch (filtered LAION) | +| 12 | 0.1254 | P2 | | +| 13 | 0.1254 | P2 | | +| 14 | 0.1248 | P2 | | +| 15 | 0.1026 | P2 | (parallel run: 0.1030) | +| 16 | **0.0947** | P2 | **Best ever** (γ=2.0) | +| 17 | 0.1083 | P2 | γ back to 3.0 | +| 18 | 0.1207 | P3 | Domain shift jump (early DM mix) | +| 19 | 0.1037–0.1201 | P3 | branched runs | +| 20 | 0.1029–0.1201 | P3 | branched runs | +| 21 | 0.1030–0.1191 | P3 | | +| 22 | 0.1191 | P3 | End of DM phase | +| 38 | ~0.119 | P5 | | +| 42 | ~0.115 | P6 | Released as `sd_epoch_042.pt` | +| 43 | 0.1202 | P7 | | +| 44 | 0.1193 | P7 | | + +--- + +## Data Pipeline + +``` +01_download_metadata.py → snapshot_download LAION-2B-en aesthetic parquet shards +01b_download_diffusiondb.py → 500 zipped shards of DiffusionDB → 512×512 JPEG tars +01c_download_journeydb_images.py→ 10 JourneyDB tgz archives → 512×512 JPEG tars +02_filter_metadata.py → Quality filter (aesthetic, CLIP sim, dimensions, watermark, NSFW, dedup) +03_download_images.py → img2dataset → WebDataset tar shards (LAION pipeline) +03_build_hf_dataset.py → DiffusionDB/JourneyDB → Arrow HF dataset +04_preprocess_to_cache.py → Tar shards → image_key + 77-token CLIP IDs → parquet batches +05_build_hf_dataset.py → Parquet batches → Arrow HF dataset (train/val split) +encode_latents.py → 4-stage pipeline: shard prefetch → tar extract → + cv2 decode (16 workers, ~3× PIL) → DMA via pinned buffer → + VAE encode (BF16, torch.compile, ~20–30% boost) → fp16 .npy (32 KB each) +SD_Train.py → DDP training loop +inference.py → DDIM sampling with EMA + monkey-patched no-clamp `step()` +SD_ImageGen.py → CUDA inference with proper negative-prompt support +``` + +--- + +## Inference Settings + +| Parameter | Value | Notes | +|-----------|-------|-------| +| DDIM steps | 100 | 50 is insufficient for faces | +| CFG scale | 7.5 (scenes) / 8.5 (portraits) | 9.0+ causes artifacts | +| EMA weights | Required | Live weights are lower quality | +| DDIM clamp | **Removed at inference** | Original `pred_x0.clamp(-1,1)` still in `SD_Model.py:736`; `inference.py` monkey-patches `DDIMScheduler.step` to skip it (SD latents std ≈ 4.0) | +| Negative prompt | Use `SD_ImageGen.py` | `inference.py --negative` is parsed but not wired into `generate()` (still uses empty string) | +| Inference device | CUDA or MPS (Apple Silicon) | `inference.py` auto-detects via `get_device()` | + +--- + +## Key Bugs Fixed + +| Bug | Fix | +|-----|-----| +| `total_steps` used dataset length instead of loader length | Broke cosine LR scheduler | +| LR default was 1e-5 not 5e-5 | Corrected | +| V-prediction validation used DDPM alphas | Changed to DDIM alphas | +| DDIM latent clamp `[-1,1]` | Removed at inference via monkey-patch (training-time `SD_Model.py:736` still clamps) | +| EMA prefix stripping | Handles `module.`, `unet.`, `_orig_mod.`, `_fsdp_wrapped_module.` prefixes | +| `channels_last` on sm_120 | Temporarily switched to `contiguous_format` during Phase 4; current `SD_Train.py` uses `channels_last` again | +| Negative prompt `--negative` arg in `inference.py` | Argparsed but not threaded into `generate()`; use `SD_ImageGen.py` for true negative-prompt CFG | +| Validation grid quality | Still uses live UNet weights, not EMA | + +--- + +## Known Issues / Still Unfixed + +- `validate()` in `SD_Train.py` *does* call `ema.apply_shadow()` for the validation pass, but the saved `val_epoch_*.png` grids were generated before that fix and still reflect live-weight quality +- Training-time `DDIMScheduler.step` in `SD_Model.py:736` still clamps `pred_x0` to [-1, 1]; only inference scripts patch it out +- `inference.py --negative` parsed but never plumbed into `generate()`; use `SD_ImageGen.py` for negative prompts +- Face shape slightly elongated, eyes slightly too large (Phase 7 should fix) +- Left arm anatomy in full-body (partially fixed in Phase 5) +- `LATENT_DIR` default is `laion_latents/laion_latents` (nested) — verify before launching new phases + +--- + +## Infrastructure Lessons + +| Issue | Lesson | +|-------|--------| +| RunPod network bandwidth | ~100 KB/s public, ~300 MB/s datacenter (HuggingFace) | +| pip re-downloading torch | Always use `--no-deps` for package installs | +| HF cache on container disk | Set `HF_DATASETS_CACHE` and `HF_HOME` to `/workspace` | +| PyTorch + sm_120 (RTX 5090) | Requires PyTorch 2.6+ cu124, `channels_last` breaks | +| CUDA 13.2 pods | Incompatible with PyTorch 2.6+cu124 — use CUDA 12.4 image | +| Checkpoint saving | Atomic write (`.tmp` then `os.replace`) prevents corruption | +| Mid-epoch crashes | `save_steps` parameter + batch fast-forward for fault tolerance (also covers deliberate `KeyboardInterrupt` resumes) | +| Latent storage | Pre-encode to .npy (32 KB each), load all to RAM at start | + +--- + +## Dataset Sources Used + +| Dataset | Images | Used For | +|---------|--------|---------| +| LAION-2B-en aesthetic | 1.3M / 213k filtered | Phases 1, 2, 6, 7 | +| DiffusionDB | ~205k | Phase 3, 7 | +| JourneyDB | ~277k | Phase 3, 7 | +| VGGFace2 | 51–159k | Phases 4, 6, 7 | +| COCO (detection-datasets) | 59k | Phases 5, 6, 7 | + +--- + +## Output Quality at Epoch 42 + +| Category | Quality | Status | +|----------|---------|--------| +| Forest / nature | 9.5/10 | ✅ Excellent | +| Landscapes | 9/10 | ✅ Excellent | +| Cyberpunk city | 8.5/10 | ✅ Very good | +| Vehicles (car) | 7/10 | ✅ Good | +| Portrait faces | 7/10 | 🔧 Eyes slightly large | +| Full body | 6/10 | 🔧 Arm anatomy | +| Animals | 6/10 | ⚠️ Face anatomy weak | +| Architecture | 7/10 | ✅ Mostly good | +| Food | 5/10 | ⚠️ Category confusion | + +--- + +## Recommended Next Steps (Post Phase 7) + +1. **Rectified flow fine-tuning** — highest effort-to-impact for field relevance +2. **LCM distillation** — enables 1–4 step generation +3. **ControlNet** — adds spatial conditioning +4. **SD_Model_v2.py** — MM-DiT backbone, dual CLIP-L + OpenCLIP-bigG, rectified flow (already designed) + +--- + +## File Structure + +### Repo layout (`/Users/atandrabharati/Desktop/Computer Vision/Stable Diffusion/`) +``` +├── SD_Model.py # UNet + VAE/CLIP wrappers + DDPM/DDIM schedulers +├── SD_Train.py # 2× RTX 5090 DDP + BF16 training loop +├── SD_ImageGen.py # CUDA inference (full negative-prompt CFG support) +├── inference.py # Apple-Silicon / CUDA inference (DDIM clamp monkey-patched) +├── encode_latents.py # 4-stage VAE → .npy pipeline +├── 01_download_metadata.py # LAION parquet snapshot +├── 01b_download_diffusiondb.py # DiffusionDB shards → 512×512 tar +├── 01c_download_journeydb_images.py # JourneyDB archives → 512×512 tar +├── 02_filter_metadata.py # aesthetic / CLIP / watermark / NSFW / dedup filter +├── 03_download_images.py # img2dataset LAION downloader +├── 03_build_hf_dataset.py # DiffusionDB/JourneyDB HF dataset +├── 04_preprocess_to_cache.py # Tars → parquet (image_key + tokens) +├── 05_build_hf_dataset.py # Parquet → Arrow HF dataset +├── sd_epoch_042.pt # Released checkpoint (~12.5 GB) +├── blog_post.md # Public write-up +├── summary.md # (this file) +├── sd-val-imgs/ # val_epoch_001..043.png + ad-hoc test grids +├── sd-logs/ # Captured `training.log` / `output*.log` runs +├── sd-test-imgs/ # Inference smoke tests +├── generated_images/ # Curated final renders (epoch 42) +├── Training Screenshots/ # WandB / terminal screenshots + claude notes +├── diffusion/ # local venv (Python 3.12) +├── .vscode/ # Pyright off, autoimport on +└── __pycache__/ # cached SD_Model bytecode +``` + +### Training environment (`/workspace/StableDiffusion/` on RunPod) +``` +/workspace/StableDiffusion/ +├── SD_Model.py / SD_Train.py / inference.py / encode_latents.py / 01-05_*.py +├── checkpoints/ # Epoch checkpoints (~12 GB each) +├── outputs/ # Validation grids per epoch +├── laion_latents/ # 213k LAION .npy latents +├── vggface_latents/ # 51k VGGFace2 .npy latents +├── coco_latents/ # 59k COCO .npy latents +├── dm_latents/ # 705k DiffusionDB+JourneyDB .npy latents +├── p7_latents/ # 572k mixed symlinks for Phase 7 +├── laion_hf_dataset/ # Arrow dataset +├── vggface_hf_dataset_v2/ # Arrow dataset (corrected keys) +├── coco_hf_dataset/ # Arrow dataset +├── dm_hf_dataset/ # Arrow dataset +└── p7_hf_dataset/ # Mixed Arrow dataset for Phase 7 +``` diff --git a/tests/test_ddim_step.py b/tests/test_ddim_step.py new file mode 100644 index 0000000..d825dc4 --- /dev/null +++ b/tests/test_ddim_step.py @@ -0,0 +1,97 @@ +""" +Test that DDIM with eta=0 preserves latent variance through the denoising +trajectory (i.e., the denoising process doesn't collapse or explode). + +Run: + python -m pytest tests/test_ddim_step.py -v + # or directly: + python tests/test_ddim_step.py +""" + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "src")) + +import torch + + +def test_ddim_variance_preservation(): + """ + Verify that DDIM doesn't produce extreme latent shifts. + The std of latents should stay within reasonable bounds throughout + the denoising trajectory (not collapse to 0 or explode). + """ + from model import DDIMScheduler + + ddim = DDIMScheduler(steps=1000, clamp_pred_x0=False) + num_steps = 50 + ddim.set_timesteps(num_steps, device="cpu") + + # Start from realistic noise (pure Gaussian) + x_t = torch.randn(2, 4, 16, 16) # half res for speed + initial_std = x_t.std().item() + + stds = [] + for i, t in enumerate(ddim.timesteps): + noise_pred = torch.randn_like(x_t) # simulate UNet predicting near-Gaussian + x_t = ddim.step(noise_pred, t, x_t, eta=0.0) + stds.append(x_t.std().item()) + + final_std = stds[-1] + std_ratio = final_std / initial_std + + # With random noise predictions, the latents shouldn't collapse: + # a perfectly denoised image would have much lower std than noise, + # but with random predictions the variance should stay bounded. + assert 0.1 < std_ratio < 3.0, ( + f"DDIM variance collapsed/extreme: initial_std={initial_std:.4f}, " + f"final_std={final_std:.4f}, ratio={std_ratio:.4f}" + ) + print(f" ✓ DDIM variance preservation (std ratio: {std_ratio:.4f})") + + +def test_ddim_determinism(): + """Same seed + same inputs should produce identical outputs.""" + from model import DDIMScheduler + + ddim = DDIMScheduler(steps=1000, clamp_pred_x0=False) + ddim.set_timesteps(10, device="cpu") + + torch.manual_seed(0) + x_t = torch.randn(1, 4, 8, 8) + noise = torch.randn_like(x_t) + + out1 = ddim.step(noise.clone(), ddim.timesteps[0], x_t.clone(), eta=0.0) + out2 = ddim.step(noise.clone(), ddim.timesteps[0], x_t.clone(), eta=0.0) + + assert torch.equal(out1, out2), "DDIM should be deterministic with eta=0" + print(" ✓ DDIM determinism") + + +def test_ddim_stochasticity(): + """eta > 0 should produce different (non-deterministic) outputs.""" + from model import DDIMScheduler + + ddim = DDIMScheduler(steps=1000, clamp_pred_x0=False) + ddim.set_timesteps(10, device="cpu") + + torch.manual_seed(42) + x_t = torch.randn(1, 4, 8, 8) + noise = torch.randn_like(x_t) + + out_det = ddim.step(noise.clone(), ddim.timesteps[0], x_t.clone(), eta=0.0) + + # eta=1.0 should differ due to random noise injection + out_stoch = ddim.step(noise.clone(), ddim.timesteps[0], x_t.clone(), eta=1.0) + + assert not torch.equal(out_det, out_stoch), "eta=1.0 should differ from eta=0.0" + print(" ✓ DDIM stochasticity (eta)") + + +if __name__ == "__main__": + print("Running DDIM tests...\n") + test_ddim_variance_preservation() + test_ddim_determinism() + test_ddim_stochasticity() + print("\n✓ All DDIM tests passed.") diff --git a/tests/test_unet_forward.py b/tests/test_unet_forward.py new file mode 100644 index 0000000..72635a7 --- /dev/null +++ b/tests/test_unet_forward.py @@ -0,0 +1,111 @@ +""" +CPU smoke test: UNet forward pass, parameter count, DDIM step. + +Run: + pip install torch transformers + python -m pytest tests/test_unet_forward.py -v + # or directly: + python tests/test_unet_forward.py +""" + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "src")) + +import torch + + +def test_ddpm_add_noise(): + from model import DDPMScheduler + + sched = DDPMScheduler(steps=1000) + x = torch.randn(2, 4, 8, 8) + t = torch.randint(0, 1000, (2,)) + x_t, noise = sched.add_noise(x, t) + assert x_t.shape == x.shape, f"DDPM add_noise shape mismatch: {x_t.shape} != {x.shape}" + print(" ✓ DDPMScheduler.add_noise") + + +def test_ddim_timesteps(): + from model import DDIMScheduler + + ddim = DDIMScheduler(steps=1000) + ddim.set_timesteps(4, device="cpu") + assert ddim.timesteps.shape[0] == 4, f"DDIM timesteps count wrong: {ddim.timesteps.shape[0]} != 4" + print(" ✓ DDIMScheduler.set_timesteps") + + +def test_ddim_step(): + from model import DDIMScheduler + + ddim = DDIMScheduler(steps=1000, clamp_pred_x0=False) + ddim.set_timesteps(4, device="cpu") + noise = torch.randn(1, 4, 8, 8) + x_t = torch.randn(1, 4, 8, 8) + t = ddim.timesteps[0] + x_prev = ddim.step(noise, t, x_t, eta=0.0) + assert x_prev.shape == x_t.shape, f"DDIM step shape mismatch: {x_prev.shape} != {x_t.shape}" + assert not torch.isnan(x_prev).any(), "DDIM step produced NaN" + print(" ✓ DDIMScheduler.step") + + +def test_ddim_clamp_opt_in(): + from model import DDIMScheduler + + ddim_clamp = DDIMScheduler(steps=1000, clamp_pred_x0=True) + ddim_noclamp = DDIMScheduler(steps=1000, clamp_pred_x0=False) + ddim_clamp.set_timesteps(4, device="cpu") + ddim_noclamp.set_timesteps(4, device="cpu") + + noise = torch.randn(1, 4, 8, 8) * 10 # extreme noise + x_t = torch.randn(1, 4, 8, 8) * 10 + t = ddim_clamp.timesteps[0] + + out_clamp = ddim_clamp.step(noise, t, x_t, eta=0.0) + out_noclamp = ddim_noclamp.step(noise, t, x_t, eta=0.0) + assert out_clamp.shape == out_noclamp.shape + assert not torch.equal(out_clamp, out_noclamp), "clamp=True vs False should differ with extreme inputs" + print(" ✓ DDIMScheduler clamp opt-in") + + +def test_unet_forward(): + from model import UNetModel + + unet = UNetModel( + in_ch=4, out_ch=4, ch=32, + res_blks=1, attn_lvls=(1,), + ch_mults=(1, 2), heads=2, + t_dim=32, ctx_dim=64, + ) + unet.eval() + with torch.no_grad(): + latent = torch.randn(1, 4, 8, 8) + t = torch.randint(0, 1000, (1,)) + ctx = torch.randn(1, 4, 64) + out = unet(latent, t, ctx) + + assert out.shape == latent.shape, f"UNet output shape {out.shape} != {latent.shape}" + total_params = sum(p.numel() for p in unet.parameters()) + print(f" ✓ UNet forward pass ({total_params:,} params)") + + +def test_config_import(): + sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "configs")) + from config import SDConfig + cfg = SDConfig() + assert cfg.model.ch == 320 + assert cfg.training.ema_decay == 0.9999 + assert cfg.scheduler.guidance_scale == 7.5 + print(" ✓ Config import") + + +if __name__ == "__main__": + print("Running smoke tests...\n") + test_ddpm_add_noise() + test_ddim_timesteps() + test_ddim_step() + test_ddim_clamp_opt_in() + test_unet_forward() + test_config_import() + print("\n✓ All smoke tests passed.")