138 changes: 138 additions & 0 deletions scripts/nn/layers/gpt2_layer.dml
@@ -0,0 +1,138 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

# GPT-2 pre-LN transformer block.
#
# Adapted from bert_layer.dml (Maximilian Luz). Three changes:
# 1. Pre-LN ordering: LayerNorm is applied *before* each sublayer,
# not after the residual add.
# 2. Causal attention: uses attention::forward_causal so token i
# can only attend to positions 0..i.
# 3. No final LayerNorm inside the block — the caller applies ln_f
# once after all 12 blocks.
#
# The helper functions (linear_tensor_forward, layer_norm_forward) are
# identical to those in bert_layer.dml and are reused here.

source("nn/layers/affine.dml") as affine
source("nn/layers/multi_attention.dml") as attention
source("nn/layers/batch_norm1d.dml") as batch_norm
source("nn/layers/gelu.dml") as gelu

linear_tensor_forward = function(matrix[double] X, matrix[double] W, matrix[double] b, int B, int C)
return (matrix[double] out) {
/*
* Helper: linear layer on tensor input of shape (A, B*C).
*/
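# Reshape (A, B*C) -> (A*B, C) so that every token is one row, apply a
# single affine transform to all tokens at once, then restore the
# (A, B*C_new) tensor layout.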
A = nrow(X)
C_new = ncol(W)
out = affine::forward(matrix(X, rows=A*B, cols=C), W, b)
out = matrix(out, rows=A, cols=B*C_new)
}

layer_norm_forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta, double epsilon, int B, int C)
return (matrix[double] out, matrix[double] cache_mean, matrix[double] cache_var, matrix[double] cache_norm) {
/*
* Helper: layer norm via 1D batch norm on tensor input of shape (A, B*C).
*/
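# After the reshape + transpose below, each token is one column, so batch
# norm's per-column mean/variance are per-token statistics over the C
# features, which is exactly LayerNorm. gamma/beta arrive transposed as
# (C, 1) and broadcast across the token columns.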
A = nrow(X)
batch_norm_input = t(matrix(X, rows=A*B, cols=C))
emas_mat = matrix(0, rows=1, cols=A*B)
[batch_norm_out, unused1, unused2, cache_mean, cache_var, cache_norm] = batch_norm::forward(
batch_norm_input, t(gamma), t(beta), "train", emas_mat, emas_mat, 0.0, epsilon)
out = matrix(t(batch_norm_out), rows=A, cols=B*C)
}

forward = function(matrix[double] states,
int H, int T, int d, int I,
matrix[double] W_Q, matrix[double] b_Q,
matrix[double] W_K, matrix[double] b_K,
matrix[double] W_V, matrix[double] b_V,
matrix[double] W_context, matrix[double] b_context,
matrix[double] W_intermediate, matrix[double] b_intermediate,
matrix[double] W_out, matrix[double] b_out,
double epsilon_ln,
matrix[double] gamma_ln1, matrix[double] beta_ln1,
matrix[double] gamma_ln2, matrix[double] beta_ln2)
return (matrix[double] out_states) {
/*
* Forward pass for one GPT-2 pre-LN transformer block.
*
* Pre-LN ordering (GPT-2):
* out = states + Attn(LN1(states))
* out = out + MLP(LN2(out))
*
* vs BERT post-LN:
* out = LN1(states + Attn(states))
* out = LN2(out + MLP(out))
*
* Inputs (B: Batch size, T: Sequence length, D = d*H, I: MLP inner dim):
* - states: Hidden states, of shape (B, T*D).
* - H: Head count.
* - T: Sequence length.
* - d: Per-head embedding dim (D/H).
* NOTE: multi_attention.dml calls this parameter "D" internally,
* but it is the per-head dim. Scaling is sqrt(d), not sqrt(D).
* - I: Intermediate (MLP) width, typically 4*D.
* - W_Q, b_Q, W_K, b_K, W_V, b_V: Q/K/V projection weights + biases.
* - W_context, b_context: Attention output projection.
* - W_intermediate, b_intermediate: MLP first layer (D -> I).
* - W_out, b_out: MLP second layer (I -> D).
* - epsilon_ln: LayerNorm epsilon (GPT-2 uses 1e-5).
* - gamma_ln1, beta_ln1: LayerNorm 1 params, shape (1, D).
* - gamma_ln2, beta_ln2: LayerNorm 2 params, shape (1, D).
*
* Outputs:
* - out_states: Output hidden states, of shape (B, T*D).
*/
# Full embedding dim. d is per-head; D is the full width.
D = d * H

# --- Attention sub-block (pre-LN) ---
# LN1 before attention
[ln1_out, cm1, cv1, cn1] = layer_norm_forward(states, gamma_ln1, beta_ln1, epsilon_ln, T, D)

# Q, K, V projections
Q = linear_tensor_forward(ln1_out, W_Q, b_Q, T, D)
K = linear_tensor_forward(ln1_out, W_K, b_K, T, D)
V = linear_tensor_forward(ln1_out, W_V, b_V, T, D)

# Causal multi-head self-attention (GPT-2: no dropout at inference)
[context, attn_probs, dropout_mask_attn] = attention::forward_causal(Q, K, V, H, T, d, 0.0)

# Attention output projection
attn_out = linear_tensor_forward(context, W_context, b_context, T, D)

# Residual 1: states + Attn(LN1(states))
out_states = states + attn_out

# --- MLP sub-block (pre-LN) ---
# LN2 before MLP
[ln2_out, cm2, cv2, cn2] = layer_norm_forward(out_states, gamma_ln2, beta_ln2, epsilon_ln, T, D)

# MLP: expand -> GELU -> contract
mlp_hidden = linear_tensor_forward(ln2_out, W_intermediate, b_intermediate, T, D)
mlp_hidden = gelu::forward(mlp_hidden)
mlp_out = linear_tensor_forward(mlp_hidden, W_out, b_out, T, I)

# Residual 2: out + MLP(LN2(out))
out_states = out_states + mlp_out
}
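
A minimal smoke-test sketch (editorial, not part of this PR): toy dimensions and random weights stand in for a converted GPT-2 checkpoint. Only the gpt2_layer::forward signature is taken from the file above; every size and name below is hypothetical.

source("nn/layers/gpt2_layer.dml") as gpt2_layer

# Hypothetical toy sizes, not the GPT-2 124M configuration.
B = 2; H = 2; T = 4; d = 8
D = H * d   # full embedding width (16)
I = 4 * D   # MLP inner width (64)

states = rand(rows=B, cols=T*D)

# Random stand-ins for checkpoint weights.
W_Q = rand(rows=D, cols=D);  b_Q = matrix(0, rows=1, cols=D)
W_K = rand(rows=D, cols=D);  b_K = matrix(0, rows=1, cols=D)
W_V = rand(rows=D, cols=D);  b_V = matrix(0, rows=1, cols=D)
W_c = rand(rows=D, cols=D);  b_c = matrix(0, rows=1, cols=D)
W_i = rand(rows=D, cols=I);  b_i = matrix(0, rows=1, cols=I)
W_o = rand(rows=I, cols=D);  b_o = matrix(0, rows=1, cols=D)
g1 = matrix(1, rows=1, cols=D);  be1 = matrix(0, rows=1, cols=D)
g2 = matrix(1, rows=1, cols=D);  be2 = matrix(0, rows=1, cols=D)

out = gpt2_layer::forward(states, H, T, d, I,
    W_Q, b_Q, W_K, b_K, W_V, b_V, W_c, b_c,
    W_i, b_i, W_o, b_o,
    1e-5, g1, be1, g2, be2)
print("out: " + nrow(out) + " x " + ncol(out))   # expect 2 x 64, i.e. B x (T*D)
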
40 changes: 32 additions & 8 deletions scripts/nn/layers/multi_attention.dml
@@ -25,8 +25,9 @@ source("nn/layers/dropout.dml") as dropout
source("scripts/nn/util.dml") as util


forward_internal = function(matrix[double] Q, matrix[double] K,
matrix[double] V, int H, int T, int D, double dropout_p,
boolean causal)
return (matrix[double] context, matrix[double] attention, matrix[double] dropout_mask) {
/*
* Computes the forward pass for a multi-head attention layer.
@@ -39,6 +40,7 @@ forward = function(matrix[double] Q, matrix[double] K,
* - T: Sequence length.
* - D: Embedding length of single query, value, key,
* - dropout_p: Dropout probability.
* - causal: If TRUE, apply causal mask to prevent attending to future tokens.
*
* Outputs:
* - context: Token context embeddings, of shape (B, T*H*D)
@@ -56,6 +58,12 @@ forward = function(matrix[double] Q, matrix[double] K,
dropout_mask = matrix(0, rows=B, cols=H*T*T)
context = matrix(0, rows=B, cols=H*T*D)
K_norm = K / sqrt(D)
causal_mask = matrix(0, rows=T, cols=T)
if (causal) {
# Mask future positions (j > i) before softmax.
causal_mask = upper.tri(target=matrix(1, rows=T, cols=T), diag=FALSE, values=TRUE)
causal_mask = log(1 - causal_mask)
}

# For loops for tensor operations
for (batch in 1:B) {
@@ -74,9 +82,12 @@ forward = function(matrix[double] Q, matrix[double] K,
V_h = matrix(V_b[head], rows=T, cols=D)

attention_scores = Q_h %*% t(K_norm_h) # Shape (T, T)

# Causal mask: set future positions (j > i) to -inf so softmax zeros them out.
if (causal) {
attention_scores = attention_scores + causal_mask
}

# Row-wise softmax (each query row is normalized over key positions)
attention_probs_h = softmax::forward(attention_scores)

@@ -104,10 +115,22 @@ forward = function(matrix[double] Q, matrix[double] K,
context = util::transpose_ABCD_to_ACBD(context, H, T)
}

forward = function(matrix[double] Q, matrix[double] K,
matrix[double] V, int H, int T, int D, double dropout_p)
return (matrix[double] context, matrix[double] attention, matrix[double] dropout_mask) {
[context, attention, dropout_mask] = forward_internal(Q, K, V, H, T, D, dropout_p, FALSE)
}

forward_causal = function(matrix[double] Q, matrix[double] K,
matrix[double] V, int H, int T, int D, double dropout_p)
return (matrix[double] context, matrix[double] attention, matrix[double] dropout_mask) {
[context, attention, dropout_mask] = forward_internal(Q, K, V, H, T, D, dropout_p, TRUE)
}


backward = function(matrix[double] dcontext,
matrix[double] dropout_mask, matrix[double] attention, matrix[double] Q,
matrix[double] K, matrix[double] V, int H, int T,
int D, double dropout_p)
return (matrix[double] dQ, matrix[double] dK, matrix[double] dV) {
/*
@@ -124,6 +147,7 @@ backward = function(matrix[double] dcontext,
* - T: Sequence length.
* - D: Embedding length of single query, value, key.
* - dropout_p: Dropout probability.
* NOTE: no causal flag is needed here; the effect of a causal mask is
* already carried by the stored attention probabilities from forward.
*
* Outputs:
* - dQ: Gradient w.r.t. input queries, of shape (B, T*H*D).
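
For intuition, a standalone sketch (editorial, not part of the diff) of the masking trick in forward_internal: upper.tri marks the strictly-future positions j > i, and log(1 - M) maps those ones to -Inf and everything else to 0, so adding the result to the scores makes softmax assign the future exactly zero probability.

T = 4
M = upper.tri(target=matrix(1, rows=T, cols=T), diag=FALSE, values=TRUE)
log_mask = log(1 - M)          # -Inf where j > i, 0 elsewhere

scores = rand(rows=T, cols=T)
masked = scores + log_mask     # future positions driven to -Inf
print(toString(masked))        # after softmax, row i attends only to 1..i
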
25 changes: 25 additions & 0 deletions scripts/staging/llm-native/.gitignore
@@ -0,0 +1,25 @@
# Generated model weights (regenerate with tools/convert_gpt2.py)
weights/

# HuggingFace cache (when set to a local path)
.hf_cache/

# Python
__pycache__/
*.pyc
*.pyo
*.egg-info/
.eggs/

# Virtual environments
.venv/
venv/
env/

# Pytest
.pytest_cache/

# IDE / OS
.idea/
.vscode/
.DS_Store