Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Hugging Face Hub (write scope) — used by scripts/download_checkpoint.py and
# any local fine-tune resume helpers that pull/push from atandra2000/sd-from-scratch-v1.
HF_TOKEN=

# Weights & Biases — used by SD_Train.py for training-time monitoring.
WANDB_API_KEY=

# Optional: override paths if you don't use the defaults
# CHECKPOINT_DIR=./checkpoints
# LATENT_DIR=./laion_latents
# HF_HOME=./.hf-cache
46 changes: 44 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,18 +1,37 @@
# Python
__pycache__/
*.pyc
*.pyo
*.pyd
.Python
*.egg-info/
dist/
build/
.pytest_cache/
.ruff_cache/
.mypy_cache/
.coverage
htmlcov/

# Virtual environments
.env
.env.local
.env.*.local
.venv
venv/
diffusion/

# Secrets (defensive — never commit raw token files)
secrets/
*.pem
*.key

# Checkpoints and model weights
*.pt
*.pth
*.ckpt
*.safetensors
*.bin
checkpoints/
outputs/

Expand All @@ -21,27 +40,50 @@ outputs/
*.parquet
*.tar
*.zip
*.tgz
*.h5
*.hdf5
*.arrow
*.lance
laion_*/
vggface_*/
coco_*/
dm_*/
p7_*/
/workspace/

# Training logs and wandb
# HF/HuggingFace caches
.hf-cache/
.cache/huggingface/

# Training logs and W&B
*.log
logs/
wandb/
training.log
runs/
tb_logs/

# Generated images (exclude training validation grids, keep curated samples)
val_epoch_*.png
results/*.png
!results/samples/
!results/samples/*.png
!assets/*.png
!docs/images/

# Jupyter
.ipynb_checkpoints/
*.ipynb

# IDE
# IDE / editor
.idea/
.vscode/
*.swp
*.swo
*~

# OS
.DS_Store
Thumbs.db
desktop.ini
39 changes: 39 additions & 0 deletions CITATION.cff
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
cff-version: 1.2.0
title: "Stable Diffusion from Scratch (SD-From-Scratch v1)"
message: "If you use this software or the released checkpoint, please cite it as below."
type: software
authors:
- given-names: Atandra
family-names: Bharati
email: atandrabharati@gmail.com
repository-code: "https://github.com/atandra2000/StableDiffusion"
url: "https://github.com/atandra2000/StableDiffusion"
abstract: >-
A Stable Diffusion 1.x-class latent diffusion model implemented and
trained entirely from scratch in PyTorch on 2× RTX 5090 (Blackwell) GPUs.
Provides the full UNet (~860M params), DDPM/DDIM schedulers, LAION data
pipeline, dual-GPU latent encoding, DDP+BF16 training loop, and
CUDA/MPS inference. Released checkpoint sd_epoch_042.pt is hosted on
Hugging Face Hub at atandra2000/sd-from-scratch-v1.
license: MIT
version: "1.0.0"
date-released: 2026-06-02
keywords:
- stable-diffusion
- latent-diffusion
- diffusion-models
- pytorch
- text-to-image
- generative-ai
- from-scratch
- rtx-5090
- blackwell
- ddp
preferred-citation:
type: software
authors:
- given-names: Atandra
family-names: Bharati
title: "SD-From-Scratch v1: A Stable-Diffusion-class latent diffusion model trained from scratch on dual RTX 5090s"
year: 2026
url: "https://huggingface.co/atandra2000/sd-from-scratch-v1"
21 changes: 21 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2026 Atandra Bharati

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
2 changes: 1 addition & 1 deletion src/SD_ImageGen.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def load_model(checkpoint_path: str, device: torch.device) -> StableDiffusionMod
ema = EMA(model.unet, decay=0.9999)
ema.load_state_dict(ckpt["ema_state_dict"])
# Apply EMA shadow weights to the model
ema.apply_shadow()
ema.apply_shadow(model.unet)
logger.info(f"✅ EMA weights applied (step {ema.step_count}) — better image quality")

elif "model_state_dict" in ckpt:
Expand Down
17 changes: 8 additions & 9 deletions src/SD_Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,12 +660,14 @@ class DDIMScheduler:

def __init__(
self,
steps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
schedule: str = "scaled_linear",
steps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
schedule: str = "scaled_linear",
clamp_pred_x0: bool = False,
):
self.num_train_timesteps = steps
self.clamp_pred_x0 = clamp_pred_x0

if schedule == "scaled_linear":
betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, steps) ** 2
Expand All @@ -678,10 +680,6 @@ def __init__(
self.timesteps: Optional[torch.Tensor] = None
self.num_inference_steps: Optional[int] = None

def to(self, device: torch.device):
self.alphas_cumprod = self.alphas_cumprod.to(device)
return self

def set_timesteps(self, num_steps: int, device: torch.device):
"""
Compute the subset of evenly-spaced timesteps used for inference.
Expand Down Expand Up @@ -733,7 +731,8 @@ def step(

# Step 1: Estimate clean latent x̂_0 from noisy x_t
pred_x0 = (x_t - (1.0 - alpha_t).sqrt() * noise_pred) / alpha_t.sqrt()
pred_x0 = pred_x0.clamp(-1.0, 1.0) # prevent extreme values propagating
if self.clamp_pred_x0:
pred_x0 = pred_x0.clamp(-1.0, 1.0)

# Step 2: Direction from x̂_0 towards x_t (diffusion "velocity")
dir_xt = (1.0 - alpha_prev).sqrt() * noise_pred
Expand Down
14 changes: 11 additions & 3 deletions src/SD_Train.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def train_epoch(
uncond_text_emb=None, use_wandb=False, use_min_snr=True,
min_snr_gamma=5.0, cfg_dropout=0.0,
save_steps=0, ckpt_dir="checkpoints", best_loss=float("inf"),
memory_format: str = "channels_last",
) -> tuple[float, int]:
ddp_unet.train()
model.text_encoder.eval()
Expand All @@ -338,7 +339,8 @@ def train_epoch(
continue
try:
latents = batch["pixel_values"].to(device, dtype=torch.bfloat16, non_blocking=True)
latents = latents.contiguous(memory_format=torch.channels_last)
if memory_format == "channels_last":
latents = latents.contiguous(memory_format=torch.channels_last)
ids = batch["input_ids"].to(device, non_blocking=True)
mask = batch["attention_mask"].to(device, non_blocking=True)

Expand Down Expand Up @@ -603,7 +605,8 @@ def main(rank, world_size, args):

noise_scheduler = DDPMScheduler(steps=1000, beta_start=0.00085, beta_end=0.012, schedule="scaled_linear")
model = StableDiffusionModel(vae, text_enc, unet, noise_scheduler).to(device)
model.unet = model.unet.to(memory_format=torch.channels_last)
if args.memory_format == "channels_last":
model.unet = model.unet.to(memory_format=torch.channels_last)
noise_scheduler.to(device)

# ── Dataset + unconditional embedding (precomputed once) ──────────────────
Expand Down Expand Up @@ -681,6 +684,7 @@ def main(rank, world_size, args):
use_wandb=args.use_wandb, use_min_snr=args.min_snr, min_snr_gamma=args.min_snr_gamma,
cfg_dropout=args.cfg_dropout,
save_steps=args.save_steps, ckpt_dir=args.ckpt_dir, best_loss=best_loss,
memory_format=args.memory_format,
)

if avg_loss < best_loss:
Expand Down Expand Up @@ -720,7 +724,7 @@ def main(rank, world_size, args):

# Data
parser.add_argument("--cache_path", type=str, default="laion_hf_dataset/train")
parser.add_argument("--latent_dir", type=str, default="laion_latents/laion_latents")
parser.add_argument("--latent_dir", type=str, default="laion_latents")
parser.add_argument("--val_size", type=int, default=500)

# Model
Expand All @@ -743,6 +747,10 @@ def main(rank, world_size, args):
parser.add_argument("--no-min-snr", dest="min_snr", action="store_false")
parser.add_argument("--min_snr_gamma", type=float, default=5.0)
parser.add_argument("--cfg_dropout", type=float, default=0.05, help="CFG dropout probability.")
parser.add_argument("--memory_format", type=str, default="channels_last",
choices=("channels_last", "contiguous"),
help="Memory format for UNet and latents. channels_last speeds up convs on NVIDIA GPUs. "
"Use 'contiguous' for AMD or Apple Silicon.")

# Checkpointing
parser.add_argument("--save_every", type=int, default=1, help="Save checkpoint every N epochs.")
Expand Down
2 changes: 1 addition & 1 deletion src/SD_Train_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1485,7 +1485,7 @@ def main(rank: int, world_size: int, args: argparse.Namespace):
# ── Data ──────────────────────────────────────────────────────────────────
parser.add_argument("--cache_path", type=str, default="laion_hf_dataset/train",
help="Path to HuggingFace dataset (Arrow format from 05_build_hf_dataset.py).")
parser.add_argument("--latent_dir", type=str, default="laion_latents/laion_latents",
parser.add_argument("--latent_dir", type=str, default="laion_latents",
help="Directory of pre-cached .npy latent files (v1 latents are reusable).")
parser.add_argument("--val_size", type=int, default=500)
parser.add_argument("--latent_fraction", type=float, default=1.0,
Expand Down
49 changes: 27 additions & 22 deletions src/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,24 +160,30 @@ def load_model(checkpoint_path: str, device: torch.device):

@torch.no_grad()
def generate(
prompts: list,
prompts: list,
vae,
text_encoder,
unet,
tokenizer,
scheduler,
device: torch.device,
num_steps: int = 50,
guidance_scale: float = 7.5,
seed: int = 42,
height: int = 512,
width: int = 512,
device: torch.device,
num_steps: int = 50,
guidance_scale: float = 7.5,
seed: int = 42,
height: int = 512,
width: int = 512,
negative_prompts: list = None,
) -> list:
assert height % 8 == 0 and width % 8 == 0
batch_size = len(prompts)
latent_h = height // 8
latent_w = width // 8

if negative_prompts is None:
negative_prompts = [""] * batch_size
elif len(negative_prompts) == 1 and batch_size > 1:
negative_prompts = negative_prompts * batch_size

# Encode text
def encode(texts):
tok = tokenizer(
Expand All @@ -188,7 +194,7 @@ def encode(texts):
return emb.float()

cond_emb = encode(prompts)
uncond_emb = encode([""] * batch_size)
uncond_emb = encode(negative_prompts)
ctx = torch.cat([uncond_emb, cond_emb], dim=0)

# Initial noise — generate on CPU then move (MPS generator workaround)
Expand Down Expand Up @@ -279,23 +285,22 @@ def main():
all_images = []
for i in range(0, len(prompts), args.batch_size):
batch = prompts[i:i + args.batch_size]
# Apply negative prompt by replacing uncond with negative embedding
# (handled inside generate via the empty string default)
print(f"Generating: {batch[0][:80]}")
t0 = time.time()
imgs = generate(
prompts = batch,
vae = vae,
text_encoder = text_encoder,
unet = unet,
tokenizer = tokenizer,
scheduler = scheduler,
device = device,
num_steps = args.steps,
guidance_scale = args.guidance,
seed = args.seed + i,
height = args.height,
width = args.width,
prompts = batch,
vae = vae,
text_encoder = text_encoder,
unet = unet,
tokenizer = tokenizer,
scheduler = scheduler,
device = device,
num_steps = args.steps,
guidance_scale = args.guidance,
seed = args.seed + i,
height = args.height,
width = args.width,
negative_prompts = [args.negative] if args.negative else None,
)
elapsed = time.time() - t0
print(f" Done in {elapsed:.1f}s ({elapsed/args.steps:.2f}s/step)")
Expand Down
13 changes: 8 additions & 5 deletions src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,12 +682,14 @@ class DDIMScheduler:

def __init__(
self,
steps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
schedule: str = "scaled_linear",
steps: int = 1000,
beta_start: float = 0.00085,
beta_end: float = 0.012,
schedule: str = "scaled_linear",
clamp_pred_x0: bool = False,
):
self.num_train_timesteps = steps
self.clamp_pred_x0 = clamp_pred_x0

if schedule == "scaled_linear":
betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, steps) ** 2
Expand Down Expand Up @@ -751,7 +753,8 @@ def step(

# Step 1: Estimate clean latent x̂_0 from noisy x_t
pred_x0 = (x_t - (1.0 - alpha_t).sqrt() * noise_pred) / alpha_t.sqrt()
pred_x0 = pred_x0.clamp(-1.0, 1.0) # prevent extreme values propagating
if self.clamp_pred_x0:
pred_x0 = pred_x0.clamp(-1.0, 1.0) # prevent extreme values propagating

# Step 2: Direction from x̂_0 towards x_t (diffusion "velocity")
dir_xt = (1.0 - alpha_prev).sqrt() * noise_pred
Expand Down
Loading
Loading