diff --git a/articles/batch-normalization.md b/articles/batch-normalization.md new file mode 100644 index 000000000..20d048553 --- /dev/null +++ b/articles/batch-normalization.md @@ -0,0 +1,155 @@ +## Prerequisites + +Before attempting this problem, you should be comfortable with: + +- **Layer Normalization** - You just implemented normalizing across features within each sample. Batch normalization normalizes across the batch for each feature instead. +- **Mean and Variance** - Computing $\mu = \frac{1}{N}\sum x_i$ and $\sigma^2 = \frac{1}{N}\sum(x_i - \mu)^2$ along the batch axis +- **Training vs Inference** - Batch norm behaves differently in training (uses batch statistics) and inference (uses running statistics). This dual behavior is the trickiest part. + +--- + +## Concept + +Layer normalization normalizes across features within each sample (axis=1). Batch normalization flips the axis: it normalizes across the batch for each feature (axis=0). + +For a batch of $N$ samples, each with $D$ features: + +**During training:** +1. Compute mean and variance for each feature across the batch +2. Normalize: $\hat{x} = \frac{x - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$ +3. Scale and shift: $y = \gamma \cdot \hat{x} + \beta$ (learned parameters) +4. Update running statistics via exponential moving average + +**During inference:** +Use the accumulated running statistics instead of batch statistics. This is critical because at inference time you might process a single sample, making batch statistics meaningless. + +The running statistics update uses momentum $m$: + +$$\text{running\_mean} = (1 - m) \cdot \text{running\_mean} + m \cdot \mu_B$$ + +This exponential moving average gives recent batches more weight while smoothing over the entire training history. + +Why does batch norm help? It reduces "internal covariate shift": as earlier layers update during training, the distribution of inputs to later layers constantly changes. 
Batch norm re-centers and re-scales these distributions, allowing higher learning rates and faster convergence. + +--- + +## Solution + +### Intuition + +In training mode: compute mean and variance along axis=0, normalize, apply affine transform, update running stats. In inference mode: skip the batch statistics entirely and use the running mean/variance that were accumulated during training. + +### Implementation + +::tabs-start +```python +import numpy as np +from typing import Tuple, List + + +class Solution: + def batch_norm(self, x: List[List[float]], gamma: List[float], beta: List[float], + running_mean: List[float], running_var: List[float], + momentum: float, eps: float, training: bool) -> Tuple[List[List[float]], List[float], List[float]]: + x = np.array(x) + gamma = np.array(gamma) + beta = np.array(beta) + running_mean = np.array(running_mean, dtype=np.float64) + running_var = np.array(running_var, dtype=np.float64) + + if training: + batch_mean = np.mean(x, axis=0) + batch_var = np.var(x, axis=0) + x_hat = (x - batch_mean) / np.sqrt(batch_var + eps) + running_mean = (1 - momentum) * running_mean + momentum * batch_mean + running_var = (1 - momentum) * running_var + momentum * batch_var + else: + x_hat = (x - running_mean) / np.sqrt(running_var + eps) + + out = gamma * x_hat + beta + return (np.round(out, 4).tolist(), np.round(running_mean, 4).tolist(), np.round(running_var, 4).tolist()) +``` +::tabs-end + + +### Walkthrough + +Given a batch of 3 samples with 4 features, `gamma = [1,1,1,1]`, `beta = [0,0,0,0]`, training=True: + +``` +x = [[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12]] +``` + +| Step | Computation | Result | +|---|---|---| +| Batch mean (axis=0) | $[1+5+9, 2+6+10, ...] 
/ 3$ | $\mu_B = [5, 6, 7, 8]$ | +| Batch variance | $[(1-5)^2 + (5-5)^2 + (9-5)^2] / 3, ...$ | $\sigma_B^2 = [10.667, 10.667, 10.667, 10.667]$ | +| Normalize row 1 | $(1 - 5) / \sqrt{10.667}$ | $-1.2247$ | +| Normalize row 2 | $(5 - 5) / \sqrt{10.667}$ | $0.0$ | +| Normalize row 3 | $(9 - 5) / \sqrt{10.667}$ | $1.2247$ | +| Update running mean | $0.9 \cdot [0,0,0,0] + 0.1 \cdot [5,6,7,8]$ | $[0.5, 0.6, 0.7, 0.8]$ | +| Update running var | $0.9 \cdot [1,1,1,1] + 0.1 \cdot [10.667,...]$ | $[1.9667, 1.9667, 1.9667, 1.9667]$ | + +The running-statistic rows assume `momentum = 0.1`, with the running mean initialized to zeros and the running variance to ones (the PyTorch defaults). + +With `gamma = [2, 0.5, 1.5, 1]` and `beta = [1, -1, 0.5, 0]`, the affine transform would scale and shift each of the four features independently. + +### Time & Space Complexity + +- Time: $O(N \cdot D)$ where $N$ is batch size and $D$ is features +- Space: $O(N \cdot D)$ for the normalized output (plus $O(D)$ for running statistics) + +--- + +## Common Pitfalls + +### Using Batch Statistics During Inference + +During inference (`training=False`), you must use the running statistics, not the current batch. Using batch statistics at inference time means your model's output depends on what other samples are in the batch. + +::tabs-start +```python +# Wrong: always using batch statistics +batch_mean = np.mean(x, axis=0) +batch_var = np.var(x, axis=0) +x_hat = (x - batch_mean) / np.sqrt(batch_var + eps) + +# Correct: check training flag +if training: +    batch_mean = np.mean(x, axis=0) +    batch_var = np.var(x, axis=0) +    x_hat = (x - batch_mean) / np.sqrt(batch_var + eps) +else: +    x_hat = (x - running_mean) / np.sqrt(running_var + eps) +``` +::tabs-end + + +### Normalizing Along the Wrong Axis + +Batch norm normalizes across the batch (axis=0), not across features (axis=1). Normalizing across features gives you layer normalization instead.
+ +::tabs-start +```python +# Wrong: this is layer normalization (axis=1) +batch_mean = np.mean(x, axis=1, keepdims=True) + +# Correct: batch normalization normalizes across samples (axis=0) +batch_mean = np.mean(x, axis=0) +``` +::tabs-end + + +--- + +## In the GPT Project + +This becomes `model/batch_normalization.py`. While modern transformers (GPT, LLaMA) use **layer normalization** rather than batch normalization, understanding batch norm is essential context. Batch norm was the breakthrough that made training deep CNNs practical (ResNets), and the distinction between batch vs layer normalization is a common interview question. The key tradeoff: batch norm depends on batch size and behaves differently at train/eval time, while layer norm is batch-independent and consistent. + +--- + +## Key Takeaways + +- Batch normalization normalizes across the batch for each feature (axis=0), while layer normalization normalizes across features for each sample (axis=1). The axis flip changes everything about when and where each technique works best. +- The train/inference split is the hardest part: during training you use live batch statistics and update running estimates. During inference you use those accumulated running estimates because batch statistics from a single sample are meaningless. +- Running statistics use an exponential moving average controlled by momentum. This smooths over the randomness of individual mini-batches while tracking the distribution as the network learns. diff --git a/articles/multi-headed-self-attention.md b/articles/multi-headed-self-attention.md index e5599b8ec..3b56eedd7 100644 --- a/articles/multi-headed-self-attention.md +++ b/articles/multi-headed-self-attention.md @@ -17,10 +17,11 @@ If the total attention dimension is $d$ and we have $h$ heads, each head operate 1. Create $h$ heads, each with attention dimension $d/h$. 2. Run each head independently on the same input. 3. Concatenate all head outputs along the feature dimension. +4. 
Apply a learned output projection $W^O$ to combine the heads. Why does this help? Different heads learn different things. In practice, researchers have observed that some heads learn syntactic relationships (attending to the previous word), some learn semantic relationships (attending to the subject of a sentence), and some learn positional patterns (attending to nearby tokens). A single head would have to compromise between all these patterns; multiple heads can specialize. -The output shape is $(B, T, d)$, the same as a single head with the full dimension. This makes multi-head attention a drop-in replacement for single-head attention. +The final output projection $W^O$ (a linear layer of size $d \times d$) lets the model learn how to best combine information from all heads. The output shape is $(B, T, d)$, the same as a single head with the full dimension, making multi-head attention a drop-in replacement. --- @@ -28,7 +29,7 @@ The output shape is $(B, T, d)$, the same as a single head with the full dimensi ### Intuition -Create a list of `SingleHeadAttention` modules, each with attention dimension `attention_dim // num_heads`. Run each head on the same input. Concatenate outputs along the last dimension. +Create a list of `SingleHeadAttention` modules, each with attention dimension `attention_dim // num_heads`. Run each head on the same input. Concatenate outputs along the last dimension. Apply a learned output projection ($W^O$) to the concatenated result. 
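The concat-and-project pipeline is easy to sanity-check on shapes alone before writing the full module. A minimal standalone sketch (the random tensors below stand in for real head outputs, and `W_O` is an untrained linear layer):

```python
import torch
import torch.nn as nn

B, T, d, h = 2, 3, 8, 4                                       # batch, seq len, model dim, heads
head_outputs = [torch.randn(B, T, d // h) for _ in range(h)]  # four (B, T, 2) head outputs
concatenated = torch.cat(head_outputs, dim=2)                 # restores (B, T, 8)
W_O = nn.Linear(d, d, bias=False)                             # learned output projection
out = W_O(concatenated)
print(out.shape)  # torch.Size([2, 3, 8])
```

Note that `attention_dim` must be divisible by `num_heads`, or the concatenated width will not match the projection's expected input.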
### Implementation @@ -46,13 +47,14 @@ class MultiHeadedSelfAttention(nn.Module): self.att_heads = nn.ModuleList() for i in range(num_heads): self.att_heads.append(self.SingleHeadAttention(embedding_dim, attention_dim // num_heads)) + self.output_proj = nn.Linear(attention_dim, attention_dim, bias=False) def forward(self, embedded: TensorType[float]) -> TensorType[float]: head_outputs = [] for head in self.att_heads: head_outputs.append(head(embedded)) concatenated = torch.cat(head_outputs, dim = 2) - return torch.round(concatenated, decimals=4) + return torch.round(self.output_proj(concatenated), decimals=4) class SingleHeadAttention(nn.Module): def __init__(self, embedding_dim: int, attention_dim: int): @@ -92,8 +94,9 @@ For `embedding_dim = 8`, `attention_dim = 8`, `num_heads = 4`, sequence of 3 tok | Head 2 | $(B, 3, 8)$ | $(B, 3, 2)$ | $(B, 3, 2)$ | | Head 3 | $(B, 3, 8)$ | $(B, 3, 2)$ | $(B, 3, 2)$ | | Concat | 4 outputs along dim=2 | | $(B, 3, 8)$ | +| $W^O$ | Linear projection $d \to d$ | | $(B, 3, 8)$ | -Each head projects from 8 to 2 dimensions (8/4 = 2), and concatenation restores the full 8 dimensions. +Each head projects from 8 to 2 dimensions (8/4 = 2), concatenation restores the full 8 dimensions, and $W^O$ learns how to best combine the heads' outputs. ### Time & Space Complexity @@ -148,6 +151,6 @@ This becomes `model/multi_head_attention.py`. The GPT model uses multi-headed at ## Key Takeaways -- Multi-headed attention runs several attention heads in parallel, each specializing in different relationship patterns, without increasing total computation over a single large head. +- Multi-headed attention runs several attention heads in parallel, each specializing in different relationship patterns, with a learned output projection ($W^O$) that combines their outputs. - Each head operates on a $d/h$ dimensional subspace, and concatenation reconstructs the full dimension, making it a drop-in replacement for single-head attention. 
- Using `nn.ModuleList` (not a plain Python list) is essential so PyTorch can track and update each head's parameters during training. diff --git a/articles/multi-layer-backpropagation.md b/articles/multi-layer-backpropagation.md new file mode 100644 index 000000000..b1f4bee10 --- /dev/null +++ b/articles/multi-layer-backpropagation.md @@ -0,0 +1,153 @@ +## Prerequisites + +Before attempting this problem, you should be comfortable with: + +- **Single-neuron backpropagation** - The chain rule through one neuron ($z \to \sigma \to L$). Now you're chaining through multiple layers, but the principle is identical. +- **ReLU activation** - Unlike sigmoid, ReLU's derivative is binary: 1 where $z > 0$, 0 elsewhere. This creates the "dead neuron" problem when $z \leq 0$ for all inputs. +- **Matrix multiplication** - Gradients through linear layers involve transposing weight matrices. Understanding $z = xW^T + b$ and its Jacobian is essential. + +--- + +## Concept + +Single-neuron backprop had three links in the chain rule: loss $\to$ activation $\to$ weights. A multi-layer network has more links but the same idea: multiply local derivatives as you walk backward from the loss. + +For a 2-layer MLP with ReLU: + +$$x \xrightarrow{W_1, b_1} z_1 \xrightarrow{\text{ReLU}} a_1 \xrightarrow{W_2, b_2} z_2 \xrightarrow{\text{MSE}} L$$ + +Each arrow is one step in the chain rule. Working backward from $L$: + +1. $\frac{\partial L}{\partial z_2}$ is the error signal from MSE +2. $\frac{\partial L}{\partial W_2}$ and $\frac{\partial L}{\partial b_2}$ use $a_1$ (the layer's input) +3. $\frac{\partial L}{\partial a_1}$ passes the gradient backward through $W_2$ +4. $\frac{\partial L}{\partial z_1}$ multiplies by the ReLU mask (binary: 1 or 0) +5. $\frac{\partial L}{\partial W_1}$ and $\frac{\partial L}{\partial b_1}$ use $x$ (the network's input) + +The ReLU derivative is the critical piece: where $z_1 > 0$, the gradient passes through unchanged. 
Where $z_1 \leq 0$, the gradient is zeroed out. This is why neurons can "die" during training: if a neuron's pre-activation is always negative, it permanently stops learning. + +--- + +## Solution + +### Intuition + +Run the forward pass to get all intermediate values ($z_1$, $a_1$, $z_2$), compute MSE loss, then walk backward applying the chain rule at each layer. Each layer's weight gradient is the outer product of the incoming error signal and the layer's input. + +### Implementation + +::tabs-start +```python +import numpy as np +from typing import List + + +class Solution: + def forward_and_backward(self, + x: List[float], + W1: List[List[float]], b1: List[float], + W2: List[List[float]], b2: List[float], + y_true: List[float]) -> dict: + x = np.array(x) + W1 = np.array(W1) + b1 = np.array(b1) + W2 = np.array(W2) + b2 = np.array(b2) + y_true = np.array(y_true) + + # Forward pass + z1 = x @ W1.T + b1 # pre-activation layer 1 + a1 = np.maximum(0, z1) # ReLU activation + z2 = a1 @ W2.T + b2 # output (predictions) + loss = np.mean((z2 - y_true) ** 2) + + # Backward pass + n = len(y_true) if y_true.ndim > 0 else 1 + dz2 = 2 * (z2 - y_true) / n # dL/dz2 + dW2 = dz2.reshape(-1, 1) @ a1.reshape(1, -1) # dL/dW2 + db2 = dz2 # dL/db2 + + da1 = dz2.reshape(1, -1) @ W2 # dL/da1 + da1 = da1.flatten() + dz1 = da1 * (z1 > 0).astype(float) # ReLU derivative + dW1 = dz1.reshape(-1, 1) @ x.reshape(1, -1) # dL/dW1 + db1 = dz1 # dL/db1 + + return { + 'loss': round(float(loss), 4), + 'dW1': np.round(dW1, 4).tolist(), + 'db1': np.round(db1, 4).tolist(), + 'dW2': np.round(dW2, 4).tolist(), + 'db2': np.round(db2, 4).tolist(), + } +``` +::tabs-end + + +### Walkthrough + +Given `x = [1.0, 2.0]`, `W1 = [[1, 0], [0, 1]]` (identity), `b1 = [0, 0]`, `W2 = [[0.5, 0.5]]`, `b2 = [0]`, `y_true = [1.0]`: + +| Step | Operation | Result | +|---|---|---| +| Layer 1 linear | $z_1 = x \cdot W_1^T + b_1$ | $[1.0, 2.0]$ | +| ReLU | $a_1 = \max(0, z_1)$ | $[1.0, 2.0]$ (all positive, all pass) | +| Layer 
2 linear | $z_2 = a_1 \cdot W_2^T + b_2$ | $[1.5]$ | +| MSE loss | $(1.5 - 1.0)^2$ | $0.25$ | +| Output gradient | $\frac{2(1.5 - 1.0)}{1}$ | $dz_2 = [1.0]$ | +| Layer 2 weights | $[1.0] \cdot [1.0, 2.0]$ | $dW_2 = [[1.0, 2.0]]$ | +| Gradient to $a_1$ | $[1.0] \cdot W_2 = [0.5, 0.5]$ | Passes through ReLU (mask is all 1s) | +| Layer 1 weights | $[0.5, 0.5]^T \cdot [1.0, 2.0]$ | $dW_1 = [[0.5, 1.0], [0.5, 1.0]]$ | + +### Time & Space Complexity + +- Time: $O(d_1 \cdot d_2 + d_2 \cdot d_3)$ where $d_i$ are layer dimensions (matrix multiplications dominate) +- Space: $O(d_1 \cdot d_2 + d_2 \cdot d_3)$ for the gradient matrices + +--- + +## Common Pitfalls + +### Forgetting the ReLU Mask + +The ReLU derivative is not 1 everywhere. Where $z_1 \leq 0$, the gradient must be zeroed. Omitting this gives incorrect gradients for any neuron that was in the "dead zone." + +::tabs-start +```python +# Wrong: gradient flows through regardless of ReLU +dz1 = da1 # ignores the ReLU mask entirely + +# Correct: multiply by the ReLU indicator +dz1 = da1 * (z1 > 0).astype(float) +``` +::tabs-end + + +### Wrong Reshape for Outer Product + +The weight gradient $dW = \delta^T \cdot x$ requires the error signal as a column vector and the input as a row vector. Forgetting to reshape gives either a scalar (dot product) or an error. + +::tabs-start +```python +# Wrong: this computes a dot product (scalar), not a matrix +dW2 = dz2 @ a1 + +# Correct: outer product via reshape +dW2 = dz2.reshape(-1, 1) @ a1.reshape(1, -1) +``` +::tabs-end + + +--- + +## In the GPT Project + +This becomes `foundations/multi_layer_backprop.py`. Understanding multi-layer backprop is what makes the rest of the course click: when you call `loss.backward()` in PyTorch, it is doing exactly these chain-rule computations automatically through the entire transformer. The ReLU dead-zone issue you encounter here also appears in the feed-forward network inside each transformer block. 
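As a quick sanity check (not part of the required solution), the walkthrough example can be replayed through PyTorch autograd to confirm the hand-derived gradients:

```python
import torch

# Walkthrough inputs: identity W1, W2 = [[0.5, 0.5]], target 1.0
x = torch.tensor([1.0, 2.0])
W1 = torch.tensor([[1.0, 0.0], [0.0, 1.0]], requires_grad=True)
b1 = torch.zeros(2, requires_grad=True)
W2 = torch.tensor([[0.5, 0.5]], requires_grad=True)
b2 = torch.zeros(1, requires_grad=True)
y_true = torch.tensor([1.0])

# Same forward pass as the solution, then let autograd do the backward pass
z1 = x @ W1.T + b1
a1 = torch.relu(z1)
z2 = a1 @ W2.T + b2
loss = torch.mean((z2 - y_true) ** 2)
loss.backward()

print(loss.item())  # 0.25, matching the table
print(W2.grad)      # [[1.0, 2.0]]
print(W1.grad)      # [[0.5, 1.0], [0.5, 1.0]]
```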
+ +--- + +## Key Takeaways + +- Multi-layer backpropagation is the same chain rule as single-neuron backprop, just applied to more links. Each layer's weight gradient is the outer product of the error signal arriving from above and the activation arriving from below. +- The ReLU derivative acts as a binary gate: gradients flow through neurons that fired ($z > 0$) and are killed for neurons that didn't. This is computationally cheap but creates the dead neuron problem. +- Saving intermediate values ($z_1$, $a_1$) during the forward pass is essential. You need them to compute gradients during the backward pass. This is why training uses more memory than inference. diff --git a/articles/transformer-block.md b/articles/transformer-block.md index 0836aa1ec..58b7a4a09 100644 --- a/articles/transformer-block.md +++ b/articles/transformer-block.md @@ -91,13 +91,14 @@ class TransformerBlock(nn.Module): self.att_heads = nn.ModuleList() for i in range(num_heads): self.att_heads.append(self.SingleHeadAttention(model_dim, model_dim // num_heads)) + self.output_proj = nn.Linear(model_dim, model_dim, bias=False) def forward(self, embedded: TensorType[float]) -> TensorType[float]: head_outputs = [] for head in self.att_heads: head_outputs.append(head(embedded)) concatenated = torch.cat(head_outputs, dim = 2) - return concatenated + return self.output_proj(concatenated) class VanillaNeuralNetwork(nn.Module): @@ -123,7 +124,7 @@ For `model_dim = 8` and `num_heads = 2`, with input shape $(B, T, 8)$: | Step | Operation | Shape | |---|---|---| | LayerNorm 1 | Normalize across dim 8 | $(B, T, 8)$ | -| Multi-Head Attention | 2 heads, each head_size=4, concat | $(B, T, 8)$ | +| Multi-Head Attention | 2 heads, each head_size=4, concat + $W^O$ | $(B, T, 8)$ | | Residual 1 | $x + \text{attention}(LN(x))$ | $(B, T, 8)$ | | LayerNorm 2 | Normalize the sum | $(B, T, 8)$ | | FFN up-project | Linear $8 \to 32$ + ReLU | $(B, T, 32)$ | diff --git a/articles/weight-initialization.md 
b/articles/weight-initialization.md new file mode 100644 index 000000000..b02ac6116 --- /dev/null +++ b/articles/weight-initialization.md @@ -0,0 +1,158 @@ +## Prerequisites + +Before attempting this problem, you should be comfortable with: + +- **Variance and standard deviation** - Understanding $\text{Var}(aX) = a^2 \text{Var}(X)$ is the core insight behind why initialization scales matter +- **Matrix multiplication** - Each layer multiplies its input by a weight matrix. If those weights have the wrong scale, the output variance changes every layer. +- **ReLU activation** - ReLU zeros out negative values, which cuts the variance roughly in half. Kaiming initialization compensates for this. + +--- + +## Concept + +Stack 10 layers with random $\mathcal{N}(0, 1)$ weights and the signal either explodes to infinity or collapses to zero. This is one of the most common reasons deep networks fail to train. + +The math is straightforward. If layer $\ell$ computes $z_\ell = W_\ell \cdot a_{\ell-1}$, then: + +$$\text{Var}(z_\ell) = n_{\ell-1} \cdot \text{Var}(W_\ell) \cdot \text{Var}(a_{\ell-1})$$ + +where $n_{\ell-1}$ is the number of input features (fan\_in). If $\text{Var}(W) = 1$ and fan\_in $= 64$, the variance multiplies by 64 every layer. After 10 layers: $64^{10} \approx 10^{18}$. The signal is gone. + +**Xavier initialization** (Glorot, 2010) sets the variance so each layer preserves its input's variance for sigmoid/tanh: + +$$\text{std} = \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}$$ + +**Kaiming initialization** (He, 2015) adapts this for ReLU, which zeros out roughly half the distribution: + +$$\text{std} = \sqrt{\frac{2}{\text{fan\_in}}}$$ + +The factor of 2 in the numerator compensates for the halved variance from ReLU. With Kaiming init, activation standard deviations stay roughly constant regardless of depth. With random init, they grow exponentially. 
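The variance relation above is easy to confirm empirically. A standalone sketch (not part of the exercise) that pushes unit-variance inputs through one naive $\mathcal{N}(0, 1)$ layer versus one Kaiming-scaled layer:

```python
import math
import torch

torch.manual_seed(0)
fan_in = 64
a = torch.randn(10000, fan_in)          # inputs with Var(a) ~ 1

W_naive = torch.randn(fan_in, fan_in)   # N(0, 1) weights, no scaling
z = a @ W_naive.T
print(z.var().item())                   # ~ 64: variance multiplied by fan_in

W_kaiming = torch.randn(fan_in, fan_in) * math.sqrt(2.0 / fan_in)
h = torch.relu(a @ W_kaiming.T)
print(h.pow(2).mean().item())           # ~ 1: Kaiming preserves the second moment through ReLU
```

One naive layer already multiplies the variance by 64; stacking ten of them gives the $64^{10}$ blow-up from the formula.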
+ +--- + +## Solution + +### Intuition + +Sample weights from $\mathcal{N}(0, \text{std}^2)$ where std is chosen to preserve variance across layers. Xavier averages fan\_in and fan\_out; Kaiming uses only fan\_in with a factor of 2 for ReLU. The `check_activations` function empirically validates this by creating weight matrices directly (as `torch.randn * std`), forwarding random data through linear + ReLU at each layer, and measuring the std. + +### Implementation + +::tabs-start +```python +import torch +import torch.nn as nn +import math + + +class Solution: + + def xavier_init(self, fan_in: int, fan_out: int) -> list[list[float]]: + torch.manual_seed(0) + std = math.sqrt(2.0 / (fan_in + fan_out)) + weights = torch.randn(fan_out, fan_in) * std + return torch.round(weights, decimals=4).tolist() + + def kaiming_init(self, fan_in: int, fan_out: int) -> list[list[float]]: + torch.manual_seed(0) + std = math.sqrt(2.0 / fan_in) + weights = torch.randn(fan_out, fan_in) * std + return torch.round(weights, decimals=4).tolist() + + def check_activations(self, num_layers: int, input_dim: int, hidden_dim: int, init_type: str) -> list[float]: + torch.manual_seed(0) + dims = [input_dim] + [hidden_dim] * num_layers + weights = [] + for i in range(num_layers): + if init_type == 'xavier': + std = math.sqrt(2.0 / (dims[i] + dims[i + 1])) + elif init_type == 'kaiming': + std = math.sqrt(2.0 / dims[i]) + else: + std = 1.0 + w = torch.randn(dims[i + 1], dims[i]) * std + weights.append(w) + + x = torch.randn(1, input_dim) + stds = [] + for w in weights: + x = x @ w.T + x = torch.relu(x) + stds.append(round(x.std().item(), 2)) + + return stds +``` +::tabs-end + + +### Walkthrough + +For `xavier_init(fan_in=4, fan_out=3)`: + +| Step | Computation | Result | +|---|---|---| +| Std formula | $\sqrt{2 / (4 + 3)} = \sqrt{2/7}$ | $0.5345$ | +| Sample | `torch.randn(3, 4)` with seed 0 | Standard normal matrix | +| Scale | Multiply by $0.5345$ | Each weight $\sim \mathcal{N}(0, 0.2857)$ | +| 
Round | 4 decimal places | `[[0.8237, -0.1568, -1.1646, 0.3038], ...]` | + +For `check_activations(5, 64, 64, 'random')` vs `'kaiming'`: + +| Init type | Layer 1 std | Layer 3 std | Layer 5 std | Trend | +|---|---|---|---|---| +| `random` | 3.97 | 166.15 | 4315.23 | Explodes exponentially | +| `kaiming` | 0.70 | 0.92 | 0.75 | Stays stable (~0.5-1.2) | + +### Time & Space Complexity + +- Time: $O(\text{fan\_in} \times \text{fan\_out})$ per weight matrix initialization +- Space: $O(\text{fan\_in} \times \text{fan\_out})$ for the weight matrix + +--- + +## Common Pitfalls + +### Using Xavier for ReLU Networks + +Xavier assumes a symmetric activation (sigmoid/tanh). ReLU kills half the distribution, so Xavier underestimates the needed variance and activations shrink toward zero in deep ReLU networks. + +::tabs-start +```python +# Wrong for ReLU: Xavier doesn't account for half the values being zeroed +std = math.sqrt(2.0 / (fan_in + fan_out)) + +# Correct for ReLU: Kaiming compensates with factor of 2/fan_in +std = math.sqrt(2.0 / fan_in) +``` +::tabs-end + + +### Forgetting torch.manual_seed + +Without setting the seed before each initialization, results are non-reproducible. The seed must be set immediately before `torch.randn` to get deterministic output. + +::tabs-start +```python +# Wrong: seed set once at module level, not before each call +# (subsequent calls get different random values) + +# Correct: set seed right before sampling +torch.manual_seed(0) +weights = torch.randn(fan_out, fan_in) * std +``` +::tabs-end + + +--- + +## In the GPT Project + +This becomes `foundations/weight_init.py`. Every `nn.Linear` layer in your GPT model uses initialization under the hood. PyTorch defaults to Kaiming for linear layers. The GPT-2 paper specifically mentions using scaled initialization ($1/\sqrt{N}$ where $N$ is the number of residual layers) to prevent the residual stream from growing too large in deep transformers. 
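A hedged sketch of what that scaled initialization might look like in code. The names (`c_proj`, `n_layer`) and the $1/\sqrt{2 \cdot n\_layer}$ reading are illustrative assumptions, since GPT-2 counts residual additions and each block contributes two (the attention output projection and the FFN output projection):

```python
import math
import torch
import torch.nn as nn

n_layer = 12                    # hypothetical transformer depth
c_proj = nn.Linear(768, 768)    # a residual-path projection inside one block

# Base std 0.02, shrunk by 1/sqrt(N) where N counts residual layers
std = 0.02 / math.sqrt(2 * n_layer)
nn.init.normal_(c_proj.weight, mean=0.0, std=std)
nn.init.zeros_(c_proj.bias)
```

Without this scaling, each block's residual addition increases the stream's variance, and a 12-block stack would inflate it roughly 24-fold.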
+ +--- + +## Key Takeaways + +- Weight initialization controls whether signals survive through deep networks. Random $\mathcal{N}(0,1)$ weights cause exponential growth or decay of activations. +- Xavier works for sigmoid/tanh by balancing fan\_in and fan\_out. Kaiming works for ReLU by accounting for the halved variance from zeroing negative values. +- The `check_activations` experiment makes the theory concrete: with random init, stds grow to thousands in 5 layers. With Kaiming, they stay near 1.0 regardless of depth.