diff --git a/scripts/nn/layers/gpt2_layer.dml b/scripts/nn/layers/gpt2_layer.dml new file mode 100644 index 00000000000..8a75af022ae --- /dev/null +++ b/scripts/nn/layers/gpt2_layer.dml @@ -0,0 +1,138 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# GPT-2 pre-LN transformer block. +# +# Adapted from bert_layer.dml (Maximilian Luz). Three changes: +# 1. Pre-LN ordering: LayerNorm is applied *before* each sublayer, +# not after the residual add. +# 2. Causal attention: uses attention::forward_causal so token i +# can only attend to positions 0..i. +# 3. No final LayerNorm inside the block — the caller applies ln_f +# once after all 12 blocks. +# +# The helper functions (linear_tensor_forward, layer_norm_forward) are +# identical to bert_layer.dml and reused here. + +source("nn/layers/affine.dml") as affine +source("nn/layers/multi_attention.dml") as attention +source("nn/layers/batch_norm1d.dml") as batch_norm +source("nn/layers/gelu.dml") as gelu + +linear_tensor_forward = function(matrix[double] X, matrix[double] W, matrix[double] b, int B, int C) + return (matrix[double] out) { + /* + * Helper: linear layer on tensor input of shape (A, B*C). + */ + A = nrow(X) + C_new = ncol(W) + out = affine::forward(matrix(X, rows=A*B, cols=C), W, b) + out = matrix(out, rows=A, cols=B*C_new) +} + +layer_norm_forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta, double epsilon, int B, int C) + return (matrix[double] out, matrix[double] cache_mean, matrix[double] cache_var, matrix[double] cache_norm) { + /* + * Helper: layer norm via 1D batch norm on tensor input of shape (A, B*C). + */ + A = nrow(X) + batch_norm_input = t(matrix(X, rows=A*B, cols=C)) + emas_mat = matrix(0, rows=1, cols=A*B) + [batch_norm_out, unused1, unused2, cache_mean, cache_var, cache_norm] = batch_norm::forward( + batch_norm_input, t(gamma), t(beta), "train", emas_mat, emas_mat, 0.0, epsilon) + out = matrix(t(batch_norm_out), rows=A, cols=B*C) +} + +forward = function(matrix[double] states, + int H, int T, int d, int I, + matrix[double] W_Q, matrix[double] b_Q, + matrix[double] W_K, matrix[double] b_K, + matrix[double] W_V, matrix[double] b_V, + matrix[double] W_context, matrix[double] b_context, + matrix[double] W_intermediate, matrix[double] b_intermediate, + matrix[double] W_out, matrix[double] b_out, + double epsilon_ln, + matrix[double] gamma_ln1, matrix[double] beta_ln1, + matrix[double] gamma_ln2, matrix[double] beta_ln2) + return (matrix[double] out_states) { + /* + * Forward pass for one GPT-2 pre-LN transformer block. 
+ * + * Pre-LN ordering (GPT-2): + * out = states + Attn(LN1(states)) + * out = out + MLP(LN2(out)) + * + * vs BERT post-LN: + * out = LN1(states + Attn(states)) + * out = LN2(out + MLP(out)) + * + * Inputs (B: Batch size, T: Sequence length, D = d*H, I: MLP inner dim): + * - states: Hidden states, of shape (B, T*D). + * - H: Head count. + * - T: Sequence length. + * - d: Per-head embedding dim (D/H). + * NOTE: multi_attention.dml calls this parameter "D" internally, + * but it is the per-head dim. Scaling is sqrt(d), not sqrt(D). + * - I: Intermediate (MLP) width, typically 4*D. + * - W_Q, b_Q, W_K, b_K, W_V, b_V: Q/K/V projection weights + biases. + * - W_context, b_context: Attention output projection. + * - W_intermediate, b_intermediate: MLP first layer (D -> I). + * - W_out, b_out: MLP second layer (I -> D). + * - epsilon_ln: LayerNorm epsilon (GPT-2 uses 1e-5). + * - gamma_ln1, beta_ln1: LayerNorm 1 params, shape (1, D). + * - gamma_ln2, beta_ln2: LayerNorm 2 params, shape (1, D). + * + * Outputs: + * - out_states: Output hidden states, of shape (B, T*D). + */ + # Full embedding dim. d is per-head; D is the full width. + D = d * H + + # --- Attention sub-block (pre-LN) --- + # LN1 before attention + [ln1_out, cm1, cv1, cn1] = layer_norm_forward(states, gamma_ln1, beta_ln1, epsilon_ln, T, D) + + # Q, K, V projections + Q = linear_tensor_forward(ln1_out, W_Q, b_Q, T, D) + K = linear_tensor_forward(ln1_out, W_K, b_K, T, D) + V = linear_tensor_forward(ln1_out, W_V, b_V, T, D) + + # Causal multi-head self-attention (GPT-2: no dropout at inference) + [context, attn_probs, dropout_mask_attn] = attention::forward_causal(Q, K, V, H, T, d, 0.0) + + # Attention output projection + attn_out = linear_tensor_forward(context, W_context, b_context, T, D) + + # Residual 1: states + Attn(LN1(states)) + out_states = states + attn_out + + # --- MLP sub-block (pre-LN) --- + # LN2 before MLP + [ln2_out, cm2, cv2, cn2] = layer_norm_forward(out_states, gamma_ln2, beta_ln2, epsilon_ln, T, D) + + # MLP: expand -> GELU -> contract + mlp_hidden = linear_tensor_forward(ln2_out, W_intermediate, b_intermediate, T, D) + mlp_hidden = gelu::forward(mlp_hidden) + mlp_out = linear_tensor_forward(mlp_hidden, W_out, b_out, T, I) + + # Residual 2: out + MLP(LN2(out)) + out_states = out_states + mlp_out +} diff --git a/scripts/nn/layers/multi_attention.dml b/scripts/nn/layers/multi_attention.dml index 7b863f34b5d..7c9c96e68ff 100644 --- a/scripts/nn/layers/multi_attention.dml +++ b/scripts/nn/layers/multi_attention.dml @@ -25,8 +25,9 @@ source("nn/layers/dropout.dml") as dropout source("scripts/nn/util.dml") as util -forward = function(matrix[double] Q, matrix[double] K, - matrix[double] V, int H, int T, int D, double dropout_p) +forward_internal = function(matrix[double] Q, matrix[double] K, + matrix[double] V, int H, int T, int D, double dropout_p, + boolean causal) return (matrix[double] context, matrix[double] attention, matrix[double] dropout_mask) { /* * Computes the forward pass for a multi-head attention layer. @@ -39,6 +40,7 @@ forward = function(matrix[double] Q, matrix[double] K, * - T: Sequence length. * - D: Embedding length of single query, value, key, * - dropout_p: Dropout probability. + * - causal: If TRUE, apply causal mask to prevent attending to future tokens. 
 *
 * Outputs:
 *  - context: Token context embeddings, of shape (B, T*H*D)
@@ -56,6 +58,12 @@ forward = function(matrix[double] Q, matrix[double] K,
   dropout_mask = matrix(0, rows=B, cols=H*T*T)
   context = matrix(0, rows=B, cols=H*T*D)
   K_norm = K / sqrt(D)
+  causal_mask = matrix(0, rows=T, cols=T)
+  if (causal) {
+    # Mask future positions (j > i) before softmax.
+    causal_mask = upper.tri(target=matrix(1, rows=T, cols=T), diag=FALSE, values=TRUE)
+    causal_mask = log(1 - causal_mask)
+  }

   # For loops for tensor operations
   for (batch in 1:B) {
@@ -74,9 +82,12 @@
       V_h = matrix(V_b[head], rows=T, cols=D)

       attention_scores = Q_h %*% t(K_norm_h) # Shape (T, T)
-
-      # TODO: Add support for attention mask here
-
+
+      # Causal mask: set future positions (j > i) to -inf so softmax zeros them out.
+      if (causal) {
+        attention_scores = attention_scores + causal_mask
+      }
+
       # Column-wise softmax
       attention_probs_h = softmax::forward(attention_scores)
@@ -104,10 +115,22 @@
   context = util::transpose_ABCD_to_ACBD(context, H, T)
 }

+forward = function(matrix[double] Q, matrix[double] K,
+    matrix[double] V, int H, int T, int D, double dropout_p)
+  return (matrix[double] context, matrix[double] attention, matrix[double] dropout_mask) {
+  [context, attention, dropout_mask] = forward_internal(Q, K, V, H, T, D, dropout_p, FALSE)
+}
+
+forward_causal = function(matrix[double] Q, matrix[double] K,
+    matrix[double] V, int H, int T, int D, double dropout_p)
+  return (matrix[double] context, matrix[double] attention, matrix[double] dropout_mask) {
+  [context, attention, dropout_mask] = forward_internal(Q, K, V, H, T, D, dropout_p, TRUE)
+}
+
-backward = function(matrix[double] dcontext,
-    matrix[double] dropout_mask, matrix[double] attention, matrix[double] Q,
-    matrix[double] K, matrix[double] V, int H, int T,
+backward = function(matrix[double] dcontext,
+    matrix[double] dropout_mask, matrix[double] attention, matrix[double] Q,
+    matrix[double] K, matrix[double] V, int H, int T,
     int D, double dropout_p)
   return (matrix[double] dQ, matrix[double] dK, matrix[double] dV) {
   /*
@@ -124,6 +147,7 @@ backward = function(matrix[double] dcontext,
 * - T: Sequence length.
 * - D: Embedding length of single query, value, key,
 * - dropout_p: Dropout probability.
+ *   Note: backward takes no causal flag; the saved attention matrix already encodes any mask applied in the forward pass.
 *
 * Outputs:
 * - dQ: Gradient w.r.t. input querys, of shape (B,T*H*D).
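
For intuition, the additive mask built above (zeros on and below the diagonal,
-inf strictly above it, via `log(1 - upper.tri)`) has a direct NumPy equivalent.
A standalone sketch, not part of the patch:

```python
import numpy as np

# Strict upper triangle (j > i) marks the future; log turns 1 -> -inf, 0 -> 0.
T = 4
m = np.triu(np.ones((T, T)), k=1)
with np.errstate(divide="ignore"):
    additive = np.log(1.0 - m)        # 0.0 on/below the diagonal, -inf above

scores = np.arange(T * T, dtype=float).reshape(T, T)
masked = scores + additive            # plain element-wise add, as in the DML
p = np.exp(masked - masked.max(axis=1, keepdims=True))
p /= p.sum(axis=1, keepdims=True)     # row-wise softmax

assert np.allclose(np.triu(p, k=1), 0.0)   # token i attends only to 0..i
```
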
diff --git a/scripts/staging/llm-native/.gitignore b/scripts/staging/llm-native/.gitignore
new file mode 100644
index 00000000000..cb5008665cf
--- /dev/null
+++ b/scripts/staging/llm-native/.gitignore
@@ -0,0 +1,25 @@
+# Generated model weights (regenerate with tools/convert_gpt2.py)
+weights/
+
+# HuggingFace cache (when set to a local path)
+.hf_cache/
+
+# Python
+__pycache__/
+*.pyc
+*.pyo
+*.egg-info/
+.eggs/
+
+# Virtual environments
+.venv/
+venv/
+env/
+
+# Pytest
+.pytest_cache/
+
+# IDE / OS
+.idea/
+.vscode/
+.DS_Store
diff --git a/scripts/staging/llm-native/README.md b/scripts/staging/llm-native/README.md
new file mode 100644
index 00000000000..523a5c57213
--- /dev/null
+++ b/scripts/staging/llm-native/README.md
@@ -0,0 +1,204 @@
+
+
+# Native LLM inference in DML (work in progress)
+
+This directory contains tooling for running pre-trained transformer language
+models natively inside SystemDS, using the existing `scripts/nn/layers/*.dml`
+operators (affine, multi-head attention with optional causal mask, layer norm,
+GELU, etc.).
+
+The first model targeted is **GPT-2 small (124M)**.
+
+## Layout
+
+```
+llm-native/
+├── tools/
+│   ├── convert_gpt2.py    # HF GPT-2 -> SystemDS CSV + MTD + manifest.json
+│   ├── pack_weights.py    # per-layer CSVs -> stacked all_*.csv (for DML driver)
+│   ├── np_oracle_gpt2.py  # pure-NumPy reference forward (debugger)
+│   └── compare_logits.py  # three-way HF / oracle / DML parity check
+├── dml/
+│   └── gpt2_inference.dml # native DML inference driver
+├── tests/
+│   └── test_convert_gpt2.py
+├── weights/               # generated; gitignored
+├── requirements.txt
+└── README.md
+```
+
+## Quick start
+
+```bash
+cd scripts/staging/llm-native
+python -m venv .venv && source .venv/bin/activate
+pip install -r requirements.txt
+
+# Convert the HF GPT-2 small checkpoint into DML-ready matrices.
+python tools/convert_gpt2.py --model gpt2 --out weights/gpt2
+
+# Pack the per-layer matrices into stacked all_*.csv files for DML.
+python tools/pack_weights.py --weights weights/gpt2
+
+# (Optional) Cross-check the converter against HuggingFace via the
+# pure-NumPy reference forward. All per-step diffs should be < 1e-11.
+python tools/np_oracle_gpt2.py --weights weights/gpt2 --compare-hf
+
+# Run native DML inference (writes logits.csv + per-block dumps).
+echo -e "464\n2068\n7586\n21831\n625\n262\n16931\n3290\n13" > weights/gpt2/tokens.csv
+export SYSTEMDS_ROOT=$PWD/../../..
+$SYSTEMDS_ROOT/bin/systemds dml/gpt2_inference.dml \
+  -nvargs weights=weights/gpt2 \
+          tokens=weights/gpt2/tokens.csv \
+          out=weights/gpt2/dml_dumps \
+          dump=TRUE
+
+# End-to-end three-way parity check (HF + NumPy oracle + DML driver):
+python tools/compare_logits.py --with-dml "Hello, my name is"
+```
+
+`compare_logits.py` is the canonical artifact for "does native DML GPT-2
+match HuggingFace?". Default mode runs only HF + oracle (~1s, no DML); pass
+`--with-dml` for the full three-way check (~10s; requires the converted
+weights *and* a successful `pack_weights.py` run). All measured per-step
+max-abs-diffs at gpt2 (124M) sit at ~1e-12, the float64 round-off floor.
+
+The converter is a one-shot Python script. After it runs, the `weights/gpt2/`
+directory contains one `<name>.csv` plus matching `<name>.csv.mtd` per
+parameter tensor, plus a `manifest.json` describing the model config, file
+map, tied weights, and SHA-256 hashes.
+
+## Why a converter?
+
+The DML transformer layers in `scripts/nn/layers/` describe **computation only**
+(matmuls, layer norm, attention). 
Trained parameter values live in HuggingFace +PyTorch checkpoints with HF-specific names and shapes. The converter is the +one-time bridge: + +1. Loads the HF `GPT2LMHeadModel` weights. +2. Splits the fused `c_attn` projection back into separate `W_Q`, `W_K`, `W_V`. +3. Reshapes biases from `(D,)` to `(1, D)` to match DML conventions. +4. Upcasts to `float64` (DML's native value type). +5. Writes one CSV + MTD pair per matrix and a `manifest.json` index. + +After conversion, DML scripts can `read("weights/gpt2/h0_W_Q.csv", format="csv")` +and feed the matrices directly to `bert_layer::forward(...)` / +`multi_attention::forward_causal(...)`. + +## HF -> DML name and shape mapping + +`B`/`T`/`D`/`H`/`I`/`V` follow the conventions of `bert_layer.dml` +(`D = n_embd = 768`, `I = 4*D = 3072`, `V = 50257` for GPT-2 small). + +| HF state-dict key | Shape (HF) | DML file (this dir) | DML role | +|--------------------------------|----------------|----------------------------|----------------------| +| `wte.weight` | `(V, D)` | `wte.csv` | token embedding | +| `wpe.weight` | `(n_ctx, D)` | `wpe.csv` | positional embedding | +| `h.i.ln_1.weight` | `(D,)` | `hi_ln1_gamma.csv` | LN1 gamma | +| `h.i.ln_1.bias` | `(D,)` | `hi_ln1_beta.csv` | LN1 beta | +| `h.i.attn.c_attn.weight[:,0:D]`| `(D, D)` | `hi_W_Q.csv` | query projection W | +| `h.i.attn.c_attn.weight[:,D:2D]`|`(D, D)` | `hi_W_K.csv` | key projection W | +| `h.i.attn.c_attn.weight[:,2D:3D]`|`(D, D)` | `hi_W_V.csv` | value projection W | +| `h.i.attn.c_attn.bias[0:D]` | `(D,)` | `hi_b_Q.csv` (1xD) | query bias | +| `h.i.attn.c_attn.bias[D:2D]` | `(D,)` | `hi_b_K.csv` | key bias | +| `h.i.attn.c_attn.bias[2D:3D]` | `(D,)` | `hi_b_V.csv` | value bias | +| `h.i.attn.c_proj.weight` | `(D, D)` | `hi_W_context.csv` | attn out W | +| `h.i.attn.c_proj.bias` | `(D,)` | `hi_b_context.csv` | attn out bias | +| `h.i.ln_2.weight` | `(D,)` | `hi_ln2_gamma.csv` | LN2 gamma | +| `h.i.ln_2.bias` | `(D,)` | `hi_ln2_beta.csv` | LN2 beta | +| `h.i.mlp.c_fc.weight` | `(D, I)` | `hi_W_intermediate.csv` | MLP expand W | +| `h.i.mlp.c_fc.bias` | `(I,)` | `hi_b_intermediate.csv` | MLP expand bias | +| `h.i.mlp.c_proj.weight` | `(I, D)` | `hi_W_out.csv` | MLP contract W | +| `h.i.mlp.c_proj.bias` | `(D,)` | `hi_b_out.csv` | MLP contract bias | +| `ln_f.weight` | `(D,)` | `lnf_gamma.csv` | final LN gamma | +| `ln_f.bias` | `(D,)` | `lnf_beta.csv` | final LN beta | +| `lm_head.weight` | tied to `wte` | (none) | recorded in manifest | + +GPT-2 uses HuggingFace's `Conv1D` linear layer, which stores weights as +`(in, out)` -- exactly what the DML affine layer (`W : (D, M)`) expects, so +no transpose is performed during conversion. + +## Manifest format + +`manifest.json` is the index DML drivers should consult first: + +```json +{ + "model": "gpt2", + "arch": "gpt2-causal", + "config": { + "n_layer": 12, "n_head": 12, "n_embd": 768, + "n_ctx": 1024, "vocab_size": 50257, + "activation": "gelu", "layer_norm_eps": 1.0e-5 + }, + "dtype": "float64", + "tied": { "lm_head": "wte" }, + "files": { "wte": "wte.csv", "...": "..." }, + "sha256": { "wte.csv": "ab12...", "...": "..." } +} +``` + +`tied.lm_head = wte` means the DML driver should reuse `wte` (transposed) for +the language-modeling head rather than expecting a separate file. 
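+
+A minimal Python sketch of a manifest consumer (illustrative only; not
+shipped in `tools/`): resolve the file map, verify the recorded SHA-256
+hashes, and honor the tied LM head instead of expecting an `lm_head` file.
+
+```python
+import hashlib, json
+from pathlib import Path
+import numpy as np
+
+weights_dir = Path("weights/gpt2")
+manifest = json.loads((weights_dir / "manifest.json").read_text())
+
+def load_matrix(rel: str) -> np.ndarray:
+    # Converter CSVs are headerless float64; force a 2-D result.
+    return np.loadtxt(weights_dir / rel, delimiter=",", dtype=np.float64, ndmin=2)
+
+# Integrity check against the hashes recorded at conversion time.
+for rel, want in manifest["sha256"].items():
+    got = hashlib.sha256((weights_dir / rel).read_bytes()).hexdigest()
+    assert got == want, f"stale or corrupt weight file: {rel}"
+
+# Tied weights: lm_head reuses wte, transposed, per manifest["tied"].
+wte = load_matrix(manifest["files"][manifest["tied"]["lm_head"]])  # (V, D)
+lm_head_W = wte.T                                                  # (D, V)
+```
+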
## CLI
+
+```
+python tools/convert_gpt2.py [options]
+
+  --model   HF model id or local path (default: gpt2)
+  --out     output directory (default: weights/<model>)
+  --dtype   {float64,float32} (default: float64)
+  --cache   HuggingFace cache directory (default: HF default)
+```
+
+## Tests
+
+```bash
+pytest scripts/staging/llm-native/tests
+```
+
+The test suite uses `sshleifer/tiny-gpt2` (a tiny 2-layer fixture, a few MB)
+to verify the converter end-to-end without downloading the full GPT-2 weights.
+
+## Implementation notes (gotchas worth knowing)
+
+A few SystemDS / Hadoop quirks that shaped the design here, recorded so the
+next person who tries this doesn't lose a day to them:
+
+1. **DML `read()` paths must be const-string-traceable.** The DML parser
+   rejects `read()` calls whose filename argument is built from runtime
+   variables -- including loop counters and `ifdef` defaults that aren't
+   string literals. Trying to `read("weights/h" + i + "_W_Q.csv", ...)` in a
+   loop fails at parse time with a `NullPointerException` deep inside
+   `StringIdentifier.getValue()`, which is a deeply unhelpful error.
+
+   Workaround used here: `pack_weights.py` `vstack`s the 12 per-layer copies
+   of each parameter into a single `all_<name>.csv` (e.g. `all_W_Q.csv` of
+   shape `(12*D, D)`). The DML driver `read()`s 16 stacked files once, then
+   row-slices inside the per-layer loop -- no runtime-built paths needed.
+   This also turns 196 naive disk reads into 20 (16 stacked files plus the
+   four shared tensors), which is why startup is bearable on a laptop.
+
+2. **Hadoop's `FileInputFormat` silently skips files whose names start with
+   `_` or `.`** (its hidden / partition-marker convention). SystemDS uses
+   the Hadoop input layer for CSV reads, so a tokens file called
+   `_tokens.csv` will produce an `InvalidInputException("Input path does not
+   exist")` at runtime even though `ls` shows it sitting right there. Name
+   inputs without a leading underscore.
diff --git a/scripts/staging/llm-native/dml/gpt2_inference.dml b/scripts/staging/llm-native/dml/gpt2_inference.dml
new file mode 100644
index 00000000000..3802ba0599a
--- /dev/null
+++ b/scripts/staging/llm-native/dml/gpt2_inference.dml
@@ -0,0 +1,192 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# DML inference driver for GPT-2 small (124M).
+#
+# Reads stacked CSV+MTD weights produced by tools/convert_gpt2.py followed
+# by tools/pack_weights.py, runs a single forward pass over a token-id
+# sequence, and writes the logits (and optionally per-block hidden-state
+# dumps) to disk. The dump files line up key-for-key with those produced
+# by tools/np_oracle_gpt2.py, so the comparison harness can localize any
+# divergence to a specific transformer block.
+#
+# Why stacked weights? SystemDS' DML parser requires ``read()`` filenames
+# to be const-string-traceable, which excludes loop-variable concatenation.
+# pack_weights.py therefore vstacks the per-layer matrices into one
+# ``all_<param>.csv`` per parameter type; the loop body row-slices them.
+#
+# Usage:
+#   python tools/pack_weights.py --weights weights/gpt2
+#   systemds scripts/staging/llm-native/dml/gpt2_inference.dml \
+#     -nvargs weights=scripts/staging/llm-native/weights/gpt2 \
+#             tokens=scripts/staging/llm-native/weights/gpt2/tokens.csv \
+#             out=scripts/staging/llm-native/weights/gpt2/dml_dumps \
+#             dump=TRUE
+#
+# Hardcoded for gpt2-small. Larger variants only need n_layer / n_head /
+# D / I changed; the loop and reads are otherwise shape-agnostic.

+source("nn/layers/gpt2_layer.dml") as gpt2_layer
+
+# ---------------------------------------------------------------------------
+# Args (literal-default ifdefs only -- DML's read() requires const paths;
+# input defaults avoid a leading underscore, see README gotcha 2)
+# ---------------------------------------------------------------------------
+
+weights = ifdef($weights, "scripts/staging/llm-native/weights/gpt2")
+tokens_path = ifdef($tokens, "scripts/staging/llm-native/weights/gpt2/tokens.csv")
+out_dir = ifdef($out, "scripts/staging/llm-native/weights/gpt2/dml_dumps")
+dump = ifdef($dump, TRUE)
+
+# ---------------------------------------------------------------------------
+# Model config (gpt2-small)
+# ---------------------------------------------------------------------------
+
+n_layer = 12
+n_head = 12
+D = 768
+d = 64       # D / n_head
+I = 4 * D    # 3072
+V = 50257
+n_ctx = 1024
+eps = 1.0e-5
+
+# ---------------------------------------------------------------------------
+# Inputs
+# ---------------------------------------------------------------------------
+
+# tokens: column vector of HF (0-indexed) token ids, shape (T, 1)
+tokens = read(tokens_path, format="csv")
+T = nrow(tokens)
+if (T > n_ctx) {
+  print("ERROR: sequence length " + T + " exceeds n_ctx " + n_ctx)
+  stop("sequence too long")
+}
+
+print("[gpt2_inference] T=" + T + " D=" + D + " n_layer=" + n_layer
+      + " weights=" + weights)
+
+# ---------------------------------------------------------------------------
+# Embedding lookup: wte[ids] + wpe[0:T]
+# ---------------------------------------------------------------------------
+
+wte = read(weights + "/wte.csv", format="csv")  # (V, D)
+wpe = read(weights + "/wpe.csv", format="csv")  # (n_ctx, D)
+
+# DML is 1-indexed; HF token ids are 0-indexed -> shift by 1.
+ids_1based = tokens + 1
+row_ids = seq(1, T, 1)
+onehot = table(row_ids, ids_1based, T, V)  # (T, V)
+tok_emb = onehot %*% wte                   # (T, D)
+pos_emb = wpe[1:T,]                        # (T, D)
+embed = tok_emb + pos_emb                  # (T, D)
+
+if (dump) { write(embed, out_dir + "/embed.csv", format="csv") }
+
+# Carry hidden states in (B=1, T*D) layout to match the block's signature.
+states = matrix(embed, rows=1, cols=T*D) + +# --------------------------------------------------------------------------- +# Stacked weights (vstack of n_layer per-layer tensors, written by pack_weights.py) +# --------------------------------------------------------------------------- + +all_ln1_gamma = read(weights + "/all_ln1_gamma.csv", format="csv") # (n_layer, D) +all_ln1_beta = read(weights + "/all_ln1_beta.csv", format="csv") # (n_layer, D) +all_W_Q = read(weights + "/all_W_Q.csv", format="csv") # (n_layer*D, D) +all_b_Q = read(weights + "/all_b_Q.csv", format="csv") # (n_layer, D) +all_W_K = read(weights + "/all_W_K.csv", format="csv") # (n_layer*D, D) +all_b_K = read(weights + "/all_b_K.csv", format="csv") # (n_layer, D) +all_W_V = read(weights + "/all_W_V.csv", format="csv") # (n_layer*D, D) +all_b_V = read(weights + "/all_b_V.csv", format="csv") # (n_layer, D) +all_W_context = read(weights + "/all_W_context.csv", format="csv") # (n_layer*D, D) +all_b_context = read(weights + "/all_b_context.csv", format="csv") # (n_layer, D) +all_ln2_gamma = read(weights + "/all_ln2_gamma.csv", format="csv") # (n_layer, D) +all_ln2_beta = read(weights + "/all_ln2_beta.csv", format="csv") # (n_layer, D) +all_W_intermediate = read(weights + "/all_W_intermediate.csv", format="csv") # (n_layer*D, I) +all_b_intermediate = read(weights + "/all_b_intermediate.csv", format="csv") # (n_layer, I) +all_W_out = read(weights + "/all_W_out.csv", format="csv") # (n_layer*I, D) +all_b_out = read(weights + "/all_b_out.csv", format="csv") # (n_layer, D) + +# --------------------------------------------------------------------------- +# Transformer blocks +# --------------------------------------------------------------------------- + +for (i in 0:n_layer-1) { + # Row slices into the stacked tensors for layer i (DML is 1-indexed). + rD0 = i * D + 1 + rD1 = (i + 1) * D + rI0 = i * I + 1 + rI1 = (i + 1) * I + rB = i + 1 + + W_Q = all_W_Q[rD0:rD1,] # (D, D) + W_K = all_W_K[rD0:rD1,] # (D, D) + W_V = all_W_V[rD0:rD1,] # (D, D) + W_context = all_W_context[rD0:rD1,] # (D, D) + W_intermediate = all_W_intermediate[rD0:rD1,] # (D, I) + W_out = all_W_out[rI0:rI1,] # (I, D) + + b_Q = all_b_Q[rB:rB,] # (1, D) + b_K = all_b_K[rB:rB,] # (1, D) + b_V = all_b_V[rB:rB,] # (1, D) + b_context = all_b_context[rB:rB,] # (1, D) + b_intermediate = all_b_intermediate[rB:rB,] # (1, I) + b_out = all_b_out[rB:rB,] # (1, D) + ln1_g = all_ln1_gamma[rB:rB,] # (1, D) + ln1_b = all_ln1_beta[rB:rB,] # (1, D) + ln2_g = all_ln2_gamma[rB:rB,] # (1, D) + ln2_b = all_ln2_beta[rB:rB,] # (1, D) + + states = gpt2_layer::forward(states, + n_head, T, d, I, + W_Q, b_Q, W_K, b_K, W_V, b_V, + W_context, b_context, + W_intermediate, b_intermediate, + W_out, b_out, + eps, + ln1_g, ln1_b, + ln2_g, ln2_b) + + if (dump) { + block_out = matrix(states, rows=T, cols=D) + write(block_out, out_dir + "/h" + i + "_out.csv", format="csv") + } +} + +# --------------------------------------------------------------------------- +# Final LayerNorm + tied LM head +# --------------------------------------------------------------------------- + +lnf_gamma = read(weights + "/lnf_gamma.csv", format="csv") # (1, D) +lnf_beta = read(weights + "/lnf_beta.csv", format="csv") # (1, D) + +# layer_norm_forward expects (A, B*C) with A=batch, B=T, C=D; states is (1, T*D). 
+[lnf_flat, cm, cv, cn] = gpt2_layer::layer_norm_forward( + states, lnf_gamma, lnf_beta, eps, T, D) +lnf_out = matrix(lnf_flat, rows=T, cols=D) # (T, D) + +if (dump) { write(lnf_out, out_dir + "/lnf.csv", format="csv") } + +# Tied LM head: logits = lnf_out %*% t(wte), shape (T, V). +logits = lnf_out %*% t(wte) +write(logits, out_dir + "/logits.csv", format="csv") + +print("[gpt2_inference] wrote logits (" + nrow(logits) + ", " + ncol(logits) + ") to " + + out_dir + "/logits.csv") diff --git a/scripts/staging/llm-native/requirements.txt b/scripts/staging/llm-native/requirements.txt new file mode 100644 index 00000000000..94572f9812f --- /dev/null +++ b/scripts/staging/llm-native/requirements.txt @@ -0,0 +1,4 @@ +numpy>=1.24 +torch>=2.1 +transformers>=4.40 +pytest>=8.0 diff --git a/scripts/staging/llm-native/tests/test_convert_gpt2.py b/scripts/staging/llm-native/tests/test_convert_gpt2.py new file mode 100644 index 00000000000..a85e6da3a80 --- /dev/null +++ b/scripts/staging/llm-native/tests/test_convert_gpt2.py @@ -0,0 +1,167 @@ +# ------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# ------------------------------------------------------------- + +"""Tests for tools/convert_gpt2.py. + +Uses `sshleifer/tiny-gpt2` (a few MB, 2 layers, 32-dim embedding) so the +suite runs end-to-end without downloading the full GPT-2 weights. + +Run with: + + pytest scripts/staging/llm-native/tests +""" + +from __future__ import annotations + +import json +import os +import sys + +import numpy as np +import pytest + + +HERE = os.path.dirname(os.path.abspath(__file__)) +TOOLS = os.path.normpath(os.path.join(HERE, os.pardir, "tools")) +sys.path.insert(0, TOOLS) + +# Importing transformers/torch is expensive; import lazily inside fixtures. 
+TINY = "sshleifer/tiny-gpt2" + + +@pytest.fixture(scope="module") +def converted(tmp_path_factory): + pytest.importorskip("transformers") + pytest.importorskip("torch") + from convert_gpt2 import convert # noqa: WPS433 (sys.path manipulated above) + + out = tmp_path_factory.mktemp("tiny_gpt2_weights") + manifest = convert(model_id=TINY, out_dir=str(out), dtype="float64") + return out, manifest + + +def _load_csv(path): + return np.loadtxt(path, delimiter=",", ndmin=2) + + +def test_manifest_top_level(converted): + out, manifest = converted + assert manifest["arch"] == "gpt2-causal" + assert manifest["dtype"] == "float64" + assert manifest["tied"] == {"lm_head": "wte"} + assert manifest["model"] == TINY + cfg = manifest["config"] + assert cfg["n_layer"] >= 1 + assert cfg["n_embd"] >= 1 + assert cfg["vocab_size"] > 0 + assert (out / "manifest.json").exists() + on_disk = json.loads((out / "manifest.json").read_text()) + assert on_disk == manifest + + +def test_all_listed_files_exist(converted): + out, manifest = converted + for name, rel in manifest["files"].items(): + assert (out / rel).exists(), f"missing CSV for {name}" + assert (out / (rel + ".mtd")).exists(), f"missing MTD for {name}" + + +def test_required_keys_per_block(converted): + _, manifest = converted + n = manifest["config"]["n_layer"] + files = manifest["files"] + assert "wte" in files and "wpe" in files + assert "lnf_gamma" in files and "lnf_beta" in files + expected = ( + "ln1_gamma ln1_beta " + "W_Q W_K W_V b_Q b_K b_V " + "W_context b_context " + "ln2_gamma ln2_beta " + "W_intermediate b_intermediate W_out b_out" + ).split() + for i in range(n): + for suffix in expected: + key = f"h{i}_{suffix}" + assert key in files, f"missing manifest key {key}" + + +def test_qkv_split_shapes(converted): + out, manifest = converted + cfg = manifest["config"] + D = cfg["n_embd"] + for i in range(cfg["n_layer"]): + for k in ("W_Q", "W_K", "W_V"): + shape = manifest["shapes"][f"h{i}_{k}"] + assert shape == [D, D], f"h{i}_{k}: got {shape}, want [{D},{D}]" + for k in ("b_Q", "b_K", "b_V"): + shape = manifest["shapes"][f"h{i}_{k}"] + assert shape == [1, D], f"h{i}_{k}: got {shape}, want [1,{D}]" + + +def test_qkv_split_values_match_huggingface(converted): + """Concatenating our split Q|K|V back together must equal HF's c_attn.""" + pytest.importorskip("transformers") + from transformers import GPT2LMHeadModel + + out, manifest = converted + cfg = manifest["config"] + D = cfg["n_embd"] + + sd = GPT2LMHeadModel.from_pretrained(TINY).transformer.state_dict() + + for i in range(cfg["n_layer"]): + Wc = sd[f"h.{i}.attn.c_attn.weight"].detach().cpu().numpy().astype(np.float64) + bc = sd[f"h.{i}.attn.c_attn.bias"].detach().cpu().numpy().astype(np.float64) + + W_Q = _load_csv(out / manifest["files"][f"h{i}_W_Q"]) + W_K = _load_csv(out / manifest["files"][f"h{i}_W_K"]) + W_V = _load_csv(out / manifest["files"][f"h{i}_W_V"]) + b_Q = _load_csv(out / manifest["files"][f"h{i}_b_Q"]).reshape(-1) + b_K = _load_csv(out / manifest["files"][f"h{i}_b_K"]).reshape(-1) + b_V = _load_csv(out / manifest["files"][f"h{i}_b_V"]).reshape(-1) + + recon_W = np.concatenate([W_Q, W_K, W_V], axis=1) + recon_b = np.concatenate([b_Q, b_K, b_V]) + + np.testing.assert_allclose(recon_W, Wc, rtol=0, atol=0, + err_msg=f"layer {i} W mismatch") + np.testing.assert_allclose(recon_b, bc, rtol=0, atol=0, + err_msg=f"layer {i} b mismatch") + assert W_Q.shape == (D, D) + + +def test_mtd_metadata_well_formed(converted): + out, manifest = converted + for name, rel in 
manifest["files"].items(): + mtd = json.loads((out / (rel + ".mtd")).read_text()) + assert mtd["data_type"] == "matrix" + assert mtd["value_type"] == "double" + assert mtd["format"] == "csv" + assert mtd["header"] is False + assert mtd["rows"] == manifest["shapes"][name][0] + assert mtd["cols"] == manifest["shapes"][name][1] + + +def test_no_lm_head_file(converted): + """lm_head must be tied, not duplicated.""" + _, manifest = converted + assert "lm_head" not in manifest["files"] + assert manifest["tied"]["lm_head"] == "wte" diff --git a/scripts/staging/llm-native/tools/compare_logits.py b/scripts/staging/llm-native/tools/compare_logits.py new file mode 100644 index 00000000000..61acf99e347 --- /dev/null +++ b/scripts/staging/llm-native/tools/compare_logits.py @@ -0,0 +1,386 @@ +# ------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# ------------------------------------------------------------- + +"""Three-way correctness check: HuggingFace ↔ NumPy oracle ↔ DML driver. + +This is the single runnable artifact that demonstrates per-layer numerical +parity between the three implementations of GPT-2 forward in this project: + + HuggingFace (PyTorch, float64) + │ reference; ground truth in the field + ▼ + NumPy oracle (tools/np_oracle_gpt2.py) + │ reads converter CSVs; pure NumPy; no DML in the loop + ▼ + DML driver (dml/gpt2_inference.dml) + │ reads stacked CSVs; SystemDS native execution + +Default mode is *lightweight*: tokenize, run HF + oracle, print per-step +max-abs-diff. Use ``--with-dml`` to additionally run the DML driver via +subprocess and add the Oracle↔DML and HF↔DML columns; this takes ~10s on +gpt2-small instead of <1s. + +Usage examples +-------------- + + # Quick sanity check (HF vs oracle only): + python tools/compare_logits.py "Hello, my name is" + + # Full three-way check (also runs DML): + python tools/compare_logits.py --with-dml "Hello, my name is" + + # Longer-prompt sanity (does parity hold at T=128?): + python tools/compare_logits.py --with-dml --tmax 128 "$(cat long_prompt.txt)" + + # Machine-readable JSON for CI: + python tools/compare_logits.py --with-dml --json "Hello, my name is" + +Exit code is 0 iff every measured max-abs-diff is ≤ ``--tolerance`` +(default 1e-9, comfortably above float64 round-off). +""" + +from __future__ import annotations + +import argparse +import json +import os +import shutil +import subprocess +import sys +import tempfile +import time +from pathlib import Path +from typing import Sequence + +import numpy as np + +# Reuse the oracle's pure-NumPy forward; keeps math centralized. 
+sys.path.insert(0, str(Path(__file__).resolve().parent)) +from np_oracle_gpt2 import forward as oracle_forward # noqa: E402 + + +# Keys whose dumps the DML driver writes; everything else (h{i}_ln1, h{i}_attn, +# h{i}_ln2) is internal to gpt2_layer::forward and only the oracle dumps it. +_DML_KEYS = ["embed"] + [f"h{i}_out" for i in range(12)] + ["lnf", "logits"] + + +# --------------------------------------------------------------------------- +# Tokenization +# --------------------------------------------------------------------------- + +def _tokenize(prompt: str, model_id: str, tmax: int | None) -> list[int]: + try: + from transformers import GPT2TokenizerFast + except ImportError as e: + raise SystemExit( + "ERROR: transformers required. pip install -r requirements.txt" + ) from e + tok = GPT2TokenizerFast.from_pretrained(model_id) + ids = tok.encode(prompt) + if tmax is not None and len(ids) > tmax: + ids = ids[:tmax] + if len(ids) == 0: + raise SystemExit("ERROR: empty prompt yields zero tokens") + return ids + + +# --------------------------------------------------------------------------- +# Reference runs +# --------------------------------------------------------------------------- + +def _run_hf(token_ids: Sequence[int], model_id: str) -> tuple[dict[str, np.ndarray], float]: + """HF forward in float64; returns dumps in oracle's key convention.""" + try: + import torch + from transformers import GPT2LMHeadModel + except ImportError as e: + raise SystemExit("ERROR: torch + transformers required") from e + + t0 = time.time() + ids = torch.tensor([list(token_ids)], dtype=torch.long) + model = GPT2LMHeadModel.from_pretrained(model_id).double().eval() + with torch.no_grad(): + out = model(ids, output_hidden_states=True) + elapsed = time.time() - t0 + + hf_hidden = [h[0].numpy().astype(np.float64) for h in out.hidden_states] + n_layer = len(hf_hidden) - 1 + + # HF mapping (see np_oracle_gpt2.compare_with_hf for derivation): + # hidden_states[0] == post-embed (oracle 'embed') + # hidden_states[1..n_layer-1] == output of block i (oracle 'h{i}_out') + # hidden_states[n_layer] == post-final-LN (oracle 'lnf') + dumps: dict[str, np.ndarray] = {"embed": hf_hidden[0]} + for i in range(n_layer - 1): + dumps[f"h{i}_out"] = hf_hidden[i + 1] + dumps["lnf"] = hf_hidden[n_layer] + dumps["logits"] = out.logits[0].numpy().astype(np.float64) + # HF doesn't expose the pre-ln_f output of the last block, so 'h{n_layer-1}_out' + # is intentionally absent from the HF dumps. + return dumps, elapsed + + +def _run_oracle(token_ids: Sequence[int], weights_dir: Path) -> tuple[dict[str, np.ndarray], float]: + t0 = time.time() + states = oracle_forward(token_ids, weights_dir, dump=True) + return states, time.time() - t0 + + +def _run_dml( + token_ids: Sequence[int], + weights_dir: Path, + systemds_bin: Path, + systemds_root: Path, + keep_tmp: bool = False, +) -> tuple[dict[str, np.ndarray], float, Path]: + """Invoke the DML driver via subprocess, return its on-disk dumps.""" + if not systemds_bin.exists(): + raise SystemExit(f"ERROR: systemds binary not found at {systemds_bin}") + + tmp = Path(tempfile.mkdtemp(prefix="gpt2_compare_")) + tokens_path = tmp / "tokens.csv" + dump_dir = tmp / "dml_dumps" + + # Write tokens (one id per line) plus the matching .mtd so SystemDS picks + # up the (T, 1) double-matrix shape without inference. 
+ # Filename intentionally does not start with '_' or '.': Hadoop's + # FileInputFormat (used for CSV reads) silently skips files matching + # those prefixes as hidden / partition markers, surfacing only as a + # cryptic "Input path does not exist" at DML runtime. + tokens_path.write_text("\n".join(str(int(i)) for i in token_ids) + "\n") + mtd = { + "data_type": "matrix", "value_type": "double", "format": "csv", + "header": False, "sep": ",", + "rows": len(token_ids), "cols": 1, + } + (tokens_path.with_suffix(".csv.mtd")).write_text(json.dumps(mtd)) + + driver = systemds_root / "scripts/staging/llm-native/dml/gpt2_inference.dml" + + env = dict(os.environ) + env["SYSTEMDS_ROOT"] = str(systemds_root) + env["SYSDS_QUIET"] = "1" + + cmd = [ + str(systemds_bin), str(driver), + "-nvargs", + f"weights={weights_dir}", + f"tokens={tokens_path}", + f"out={dump_dir}", + "dump=TRUE", + ] + + # cwd must be the repo root: the driver source()'s nn/layers/* via paths + # that resolve against the SystemDS jar's bundled scripts/ tree, which + # only takes precedence when launched from there. + t0 = time.time() + proc = subprocess.run(cmd, env=env, cwd=str(systemds_root), + capture_output=True, text=True) + elapsed = time.time() - t0 + + if proc.returncode != 0: + sys.stderr.write(proc.stdout) + sys.stderr.write(proc.stderr) + raise SystemExit(f"ERROR: DML driver exited {proc.returncode}") + + dumps: dict[str, np.ndarray] = {} + for k in _DML_KEYS: + path = dump_dir / f"{k}.csv" + if not path.exists(): + sys.stderr.write("[compare] --- DML stdout ---\n") + sys.stderr.write(proc.stdout) + sys.stderr.write("[compare] --- DML stderr ---\n") + sys.stderr.write(proc.stderr) + raise SystemExit(f"ERROR: DML driver did not produce {path}") + dumps[k] = np.loadtxt(path, delimiter=",", dtype=np.float64, ndmin=2) + + if not keep_tmp: + shutil.rmtree(tmp, ignore_errors=True) + return dumps, elapsed, tmp + + +# --------------------------------------------------------------------------- +# Comparison +# --------------------------------------------------------------------------- + +def _max_abs_diff(a: np.ndarray, b: np.ndarray) -> float: + if a.shape != b.shape: + raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}") + return float(np.abs(a - b).max()) + + +def _build_report( + hf: dict[str, np.ndarray], + oracle: dict[str, np.ndarray], + dml: dict[str, np.ndarray] | None, +) -> dict: + """Return per-step diffs across whatever pairs are available.""" + rows: list[dict] = [] + # Oracle has the most keys (h{i}_ln1/attn/ln2 too); we only report on + # rows that exist in *every* available source so the table is rectangular. 
+ base_keys = ["embed"] + [f"h{i}_out" for i in range(12)] + ["lnf", "logits"] + + for k in base_keys: + row: dict = {"key": k, "shape": list(oracle[k].shape) if k in oracle else None} + if k in hf and k in oracle: + row["hf_vs_oracle"] = _max_abs_diff(hf[k], oracle[k]) + if dml is not None and k in dml and k in oracle: + row["oracle_vs_dml"] = _max_abs_diff(oracle[k], dml[k]) + if dml is not None and k in dml and k in hf: + row["hf_vs_dml"] = _max_abs_diff(hf[k], dml[k]) + rows.append(row) + return {"rows": rows} + + +def _print_table(report: dict, with_dml: bool) -> float: + """Pretty table; returns the worst diff seen anywhere.""" + cols = ["HF vs Oracle"] + if with_dml: + cols += ["Oracle vs DML", "HF vs DML"] + header = f" {'step':<10s} | " + " | ".join(f"{c:>14s}" for c in cols) + sep = " " + "-" * (len(header) - 2) + print(header) + print(sep) + + worst = 0.0 + + def fmt(v): + if v is None: + return f"{'-':>14s}" + return f"{v:>14.3e}" + + for row in report["rows"]: + cells = [fmt(row.get("hf_vs_oracle"))] + if with_dml: + cells.append(fmt(row.get("oracle_vs_dml"))) + cells.append(fmt(row.get("hf_vs_dml"))) + for k in ("hf_vs_oracle", "oracle_vs_dml", "hf_vs_dml"): + v = row.get(k) + if v is not None and not np.isnan(v): + worst = max(worst, v) + print(f" {row['key']:<10s} | " + " | ".join(cells)) + print(sep) + print(f" {'worst':<10s} | {worst:>14.3e}") + return worst + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def _build_argparser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser( + prog="compare_logits", + description="HF vs NumPy oracle vs DML driver -- per-layer parity check.", + ) + p.add_argument("prompt", nargs="?", default="Hello, my name is", + help="text prompt to encode (default: 'Hello, my name is')") + p.add_argument("--weights", default=None, + help="converted weights dir (default: scripts/staging/llm-native/weights/gpt2)") + p.add_argument("--tmax", type=int, default=16, + help="truncate tokenization to first N tokens (default: 16)") + p.add_argument("--tolerance", type=float, default=1e-9, + help="max-abs-diff threshold for PASS (default: 1e-9)") + p.add_argument("--with-dml", action="store_true", + help="also run the DML driver and add Oracle/DML and HF/DML columns") + p.add_argument("--systemds-bin", default=None, + help="path to bin/systemds (default: $SYSTEMDS_ROOT/bin/systemds)") + p.add_argument("--systemds-root", default=None, + help="repo root (default: walk up from this script)") + p.add_argument("--keep-tmp", action="store_true", + help="keep DML temp dir (useful for debugging diffs)") + p.add_argument("--json", action="store_true", + help="machine-readable JSON output instead of pretty table") + return p + + +def main(argv: Sequence[str] | None = None) -> int: + args = _build_argparser().parse_args(list(argv) if argv is not None else None) + + here = Path(__file__).resolve() + repo_root = Path(args.systemds_root) if args.systemds_root else here.parents[4] + weights = Path(args.weights) if args.weights \ + else repo_root / "scripts/staging/llm-native/weights/gpt2" + if not (weights / "manifest.json").exists(): + raise SystemExit(f"ERROR: no manifest.json under {weights} -- run convert_gpt2.py first") + + with open(weights / "manifest.json") as f: + manifest = json.load(f) + model_id = manifest["model"] + + ids = _tokenize(args.prompt, model_id, args.tmax) + if not args.json: + print(f"[compare] model={model_id} 
T={len(ids)} prompt={args.prompt!r}", + file=sys.stderr) + + hf_dumps, hf_t = _run_hf(ids, model_id) + or_dumps, or_t = _run_oracle(ids, weights) + + dml_dumps = None + dml_t = 0.0 + dml_tmp: Path | None = None + if args.with_dml: + if args.with_dml and "stacked" not in manifest: + raise SystemExit( + f"ERROR: {weights}/manifest.json has no 'stacked' entry -- " + "run pack_weights.py before --with-dml" + ) + sysds_bin = Path(args.systemds_bin) if args.systemds_bin \ + else repo_root / "bin/systemds" + dml_dumps, dml_t, dml_tmp = _run_dml(ids, weights, sysds_bin, repo_root, + keep_tmp=args.keep_tmp) + + report = _build_report(hf_dumps, or_dumps, dml_dumps) + report["meta"] = { + "model": model_id, + "prompt": args.prompt, + "T": len(ids), + "tolerance": args.tolerance, + "elapsed_sec": {"hf": hf_t, "oracle": or_t, "dml": dml_t}, + } + + if args.json: + print(json.dumps(report, indent=2)) + worst = max( + (v for row in report["rows"] + for k, v in row.items() if k in ("hf_vs_oracle", "oracle_vs_dml", "hf_vs_dml") + and v is not None and not (isinstance(v, float) and np.isnan(v))), + default=0.0, + ) + else: + worst = _print_table(report, with_dml=args.with_dml) + print() + print(f" HF run: {hf_t:6.2f} s") + print(f" Oracle run: {or_t:6.2f} s") + if args.with_dml: + print(f" DML run: {dml_t:6.2f} s") + if dml_tmp is not None and args.keep_tmp: + print(f" DML tmp dir: {dml_tmp}") + + ok = worst <= args.tolerance + if not args.json: + verdict = "PASS" if ok else "FAIL" + print(f"\n {verdict}: worst {worst:.3e} {'≤' if ok else '>'} tol {args.tolerance:.1e}") + return 0 if ok else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/staging/llm-native/tools/convert_gpt2.py b/scripts/staging/llm-native/tools/convert_gpt2.py new file mode 100644 index 00000000000..7831cc4b555 --- /dev/null +++ b/scripts/staging/llm-native/tools/convert_gpt2.py @@ -0,0 +1,251 @@ +# ------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# ------------------------------------------------------------- + +"""Convert a HuggingFace GPT-2 checkpoint to SystemDS CSV + MTD format. + +Each parameter tensor is written as a `.csv` file (comma-separated, +float64 by default) accompanied by a `.csv.mtd` JSON metadata file that +SystemDS uses to type the matrix. A top-level `manifest.json` records the +model config, file map, tied weights (LM head), and SHA-256 hashes. + +The shape conventions follow `scripts/nn/layers/affine.dml` and +`scripts/nn/layers/bert_layer.dml`: + + - 2-D weight matrices keep their HF shape (D, M); HF's `Conv1D` already + stores weights as (in, out), matching DML's affine `W : (D, M)`. + - 1-D bias / LayerNorm vectors are reshaped from (D,) to (1, D). 
- Combined `attn.c_attn.weight` of shape (D, 3D) is column-sliced into
+    three (D, D) tensors W_Q, W_K, W_V; same for the bias vector.
+  - `lm_head.weight` is tied to `wte.weight` and recorded in the manifest
+    instead of being duplicated on disk.
+"""
+
+from __future__ import annotations
+
+import argparse
+import hashlib
+import json
+import os
+import sys
+from typing import Iterable
+
+import numpy as np
+
+
+_MTD_TEMPLATE = {
+    "data_type": "matrix",
+    "value_type": "double",
+    "format": "csv",
+    "header": False,
+    "sep": ",",
+}
+
+
+def _ensure_2d(arr: np.ndarray) -> np.ndarray:
+    """SystemDS matrices are 2-D; promote 1-D vectors to (1, N) row vectors."""
+    if arr.ndim == 1:
+        return arr.reshape(1, -1)
+    if arr.ndim == 2:
+        return arr
+    raise ValueError(f"unsupported tensor rank {arr.ndim}; only 1-D/2-D supported")
+
+
+def _sha256_of(path: str) -> str:
+    h = hashlib.sha256()
+    with open(path, "rb") as f:
+        for chunk in iter(lambda: f.read(1 << 20), b""):
+            h.update(chunk)
+    return h.hexdigest()
+
+
+def _write_matrix(out_dir: str, name: str, arr: np.ndarray, dtype: str) -> tuple[str, tuple[int, int]]:
+    """Write `<out_dir>/<name>.csv` and `<name>.csv.mtd`. Returns (relpath, shape)."""
+    arr = _ensure_2d(np.asarray(arr))
+    if dtype == "float64":
+        arr = arr.astype(np.float64, copy=False)
+        value_type = "double"
+    elif dtype == "float32":
+        arr = arr.astype(np.float32, copy=False)
+        value_type = "single"
+    else:
+        raise ValueError(f"unsupported dtype: {dtype}")
+
+    rel = f"{name}.csv"
+    csv_path = os.path.join(out_dir, rel)
+    mtd_path = csv_path + ".mtd"
+
+    fmt = "%.17g" if dtype == "float64" else "%.9g"
+    np.savetxt(csv_path, arr, fmt=fmt, delimiter=",")
+
+    mtd = dict(_MTD_TEMPLATE)
+    mtd["value_type"] = value_type
+    mtd["rows"] = int(arr.shape[0])
+    mtd["cols"] = int(arr.shape[1])
+    with open(mtd_path, "w") as f:
+        json.dump(mtd, f, indent=2)
+
+    return rel, (int(arr.shape[0]), int(arr.shape[1]))
+
+
+def _to_numpy(t):
+    """Detach a torch tensor to a contiguous CPU NumPy array (passthrough for ndarrays)."""
+    if isinstance(t, np.ndarray):
+        return t
+    return t.detach().to("cpu").contiguous().numpy()
+
+
+def convert(
+    model_id: str,
+    out_dir: str,
+    dtype: str = "float64",
+    cache_dir: str | None = None,
+) -> dict:
+    """Convert an HF GPT-2 checkpoint. Returns the in-memory manifest dict."""
+    try:
+        from transformers import GPT2LMHeadModel
+    except ImportError as e:
+        raise SystemExit(
+            "ERROR: transformers is required. pip install -r requirements.txt"
+        ) from e
+
+    os.makedirs(out_dir, exist_ok=True)
+
+    print(f"[convert_gpt2] loading {model_id} ...", file=sys.stderr)
+    model = GPT2LMHeadModel.from_pretrained(model_id, cache_dir=cache_dir)
+    model.eval()
+
+    # GPT2LMHeadModel.state_dict() prefixes inner-module keys with "transformer."
+    # and exposes lm_head.weight (which is tied to wte.weight). Use the inner
+    # GPT2Model directly so keys read like "wte.weight", "h.0.attn.c_attn.weight".
+    sd = model.transformer.state_dict()
+    cfg = model.config
+    D = cfg.n_embd
+    n_layer = cfg.n_layer
+    activation = getattr(cfg, "activation_function", "gelu_new")
+
+    files: dict[str, str] = {}
+    shapes: dict[str, list[int]] = {}
+
+    def emit(name: str, tensor) -> None:
+        rel, shape = _write_matrix(out_dir, name, _to_numpy(tensor), dtype)
+        files[name] = rel
+        shapes[name] = list(shape)
+
+    # Embeddings (used outside the per-block forward).
+    emit("wte", sd["wte.weight"])
+    emit("wpe", sd["wpe.weight"])
+
+    for i in range(n_layer):
+        # LayerNorm 1.
+ emit(f"h{i}_ln1_gamma", sd[f"h.{i}.ln_1.weight"]) + emit(f"h{i}_ln1_beta", sd[f"h.{i}.ln_1.bias"]) + + # Combined Q|K|V projection: split along the output (column) axis. + Wc = _to_numpy(sd[f"h.{i}.attn.c_attn.weight"]) # (D, 3D) + bc = _to_numpy(sd[f"h.{i}.attn.c_attn.bias"]) # (3D,) + if Wc.shape != (D, 3 * D): + raise RuntimeError( + f"unexpected c_attn.weight shape {Wc.shape}; expected ({D}, {3 * D})" + ) + emit(f"h{i}_W_Q", Wc[:, 0: D]) + emit(f"h{i}_W_K", Wc[:, D:2 * D]) + emit(f"h{i}_W_V", Wc[:, 2 * D:3 * D]) + emit(f"h{i}_b_Q", bc[ 0: D]) + emit(f"h{i}_b_K", bc[ D:2 * D]) + emit(f"h{i}_b_V", bc[ 2 * D:3 * D]) + + # Attention output projection (DML calls this W_context / b_context). + emit(f"h{i}_W_context", sd[f"h.{i}.attn.c_proj.weight"]) + emit(f"h{i}_b_context", sd[f"h.{i}.attn.c_proj.bias"]) + + # LayerNorm 2. + emit(f"h{i}_ln2_gamma", sd[f"h.{i}.ln_2.weight"]) + emit(f"h{i}_ln2_beta", sd[f"h.{i}.ln_2.bias"]) + + # Feed-forward (MLP). + emit(f"h{i}_W_intermediate", sd[f"h.{i}.mlp.c_fc.weight"]) + emit(f"h{i}_b_intermediate", sd[f"h.{i}.mlp.c_fc.bias"]) + emit(f"h{i}_W_out", sd[f"h.{i}.mlp.c_proj.weight"]) + emit(f"h{i}_b_out", sd[f"h.{i}.mlp.c_proj.bias"]) + + # Final LayerNorm. + emit("lnf_gamma", sd["ln_f.weight"]) + emit("lnf_beta", sd["ln_f.bias"]) + + # Compute SHA-256 hashes after all files are flushed. + hashes = {rel: _sha256_of(os.path.join(out_dir, rel)) for rel in files.values()} + + manifest = { + "model": model_id, + "arch": "gpt2-causal", + "dtype": dtype, + "config": { + "n_layer": int(cfg.n_layer), + "n_head": int(cfg.n_head), + "n_embd": int(cfg.n_embd), + "n_ctx": int(cfg.n_ctx), + "vocab_size": int(cfg.vocab_size), + "activation": str(activation), + "layer_norm_eps": float(getattr(cfg, "layer_norm_epsilon", 1e-5)), + }, + "tied": {"lm_head": "wte"}, + "files": files, + "shapes": shapes, + "sha256": hashes, + } + + manifest_path = os.path.join(out_dir, "manifest.json") + with open(manifest_path, "w") as f: + json.dump(manifest, f, indent=2, sort_keys=True) + + print( + f"[convert_gpt2] wrote {len(files)} matrices " + f"(+ manifest) to {out_dir}", + file=sys.stderr, + ) + return manifest + + +def _build_argparser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser( + prog="convert_gpt2", + description="Convert a HuggingFace GPT-2 checkpoint to SystemDS CSV + MTD.", + ) + p.add_argument("--model", default="gpt2", + help="HuggingFace model id or local path (default: gpt2)") + p.add_argument("--out", default=None, + help="output directory (default: weights/)") + p.add_argument("--dtype", choices=("float64", "float32"), default="float64", + help="numeric type to write (default: float64, matches DML)") + p.add_argument("--cache", default=None, + help="HuggingFace cache directory (default: HF default)") + return p + + +def main(argv: Iterable[str] | None = None) -> int: + args = _build_argparser().parse_args(list(argv) if argv is not None else None) + out = args.out or os.path.join("weights", os.path.basename(args.model.rstrip("/"))) + convert(model_id=args.model, out_dir=out, dtype=args.dtype, cache_dir=args.cache) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/staging/llm-native/tools/np_oracle_gpt2.py b/scripts/staging/llm-native/tools/np_oracle_gpt2.py new file mode 100644 index 00000000000..fb45bd6407e --- /dev/null +++ b/scripts/staging/llm-native/tools/np_oracle_gpt2.py @@ -0,0 +1,345 @@ +# ------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) 
under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# ------------------------------------------------------------- + +"""NumPy reference forward pass for GPT-2 (debugger / oracle). + +Reads the CSV + MTD weights produced by ``convert_gpt2.py`` (indexed by the +sibling ``manifest.json``) and runs a pure-NumPy GPT-2 forward pass. The main +purpose is *debugging the DML implementation*: the oracle dumps every +intermediate hidden state (after embed, after each block's LN1/attn/LN2/output, +after the final LN, and the logits) so a later comparison harness can +pinpoint exactly which sublayer diverges. + +Secondary purpose: independently validate the converter against the original +HuggingFace model without DML in the loop (``--compare-hf``). + +Numerical conventions +--------------------- +* All math runs in float64, matching DML's native value type. Source CSVs + written by the converter are also float64 by default. +* GELU uses the tanh approximation (``gelu_new``), which matches HF's GPT-2 + and the formula in ``scripts/nn/layers/gelu.dml``. +* Attention scaling is by ``sqrt(d_head)`` (per-head dim), not ``sqrt(D)``. +* The LM head is tied to the token embedding: ``logits = h_final @ wte.T``. 
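+* For reference, the exact form (mirrored by ``_gelu_new`` below) is
+  ``gelu_new(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``.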
+""" + +from __future__ import annotations + +import argparse +import json +import sys +from pathlib import Path +from typing import Sequence + +import numpy as np + + +# --------------------------------------------------------------------------- +# Manifest / weight loading +# --------------------------------------------------------------------------- + +def _load_manifest(weights_dir: Path) -> dict: + with open(weights_dir / "manifest.json") as f: + return json.load(f) + + +def _load_csv(weights_dir: Path, rel: str) -> np.ndarray: + """Load a converter-emitted CSV as a 2-D float64 ndarray.""" + return np.loadtxt(weights_dir / rel, delimiter=",", dtype=np.float64, ndmin=2) + + +def _load_all(weights_dir: Path, manifest: dict) -> dict[str, np.ndarray]: + """Eagerly load every matrix referenced by the manifest.""" + return {name: _load_csv(weights_dir, rel) for name, rel in manifest["files"].items()} + + +# --------------------------------------------------------------------------- +# Operators (match scripts/nn/layers/*.dml semantically) +# --------------------------------------------------------------------------- + +def _layer_norm(x: np.ndarray, gamma: np.ndarray, beta: np.ndarray, eps: float) -> np.ndarray: + """Per-token LayerNorm; gamma/beta are stored as (1, D) by the converter.""" + g = gamma.reshape(-1) + b = beta.reshape(-1) + mean = x.mean(axis=-1, keepdims=True) + var = x.var(axis=-1, keepdims=True) + return (x - mean) / np.sqrt(var + eps) * g + b + + +def _gelu_new(x: np.ndarray) -> np.ndarray: + """tanh approximation, matches HF gelu_new and scripts/nn/layers/gelu.dml.""" + c = np.sqrt(2.0 / np.pi) + return 0.5 * x * (1.0 + np.tanh(c * (x + 0.044715 * x ** 3))) + + +def _causal_self_attention( + x: np.ndarray, + W_Q: np.ndarray, b_Q: np.ndarray, + W_K: np.ndarray, b_K: np.ndarray, + W_V: np.ndarray, b_V: np.ndarray, + W_o: np.ndarray, b_o: np.ndarray, + n_head: int, +) -> np.ndarray: + """GPT-2 multi-head causal self-attention. Input/output shape: (T, D).""" + T, D = x.shape + d = D // n_head + + # Project. Biases stored as (1, D); broadcast on the row axis. + Q = x @ W_Q + b_Q.reshape(-1) + K = x @ W_K + b_K.reshape(-1) + V = x @ W_V + b_V.reshape(-1) + + # (T, D) -> (H, T, d): split last axis into heads, then move heads to front. + def split_heads(t: np.ndarray) -> np.ndarray: + return t.reshape(T, n_head, d).transpose(1, 0, 2) + + Qh, Kh, Vh = split_heads(Q), split_heads(K), split_heads(V) + + # Scaled dot-product per head: (H, T, T). + scores = Qh @ Kh.transpose(0, 2, 1) / np.sqrt(d) + + # Causal mask: zero out j > i (future tokens). + mask = np.triu(np.ones((T, T), dtype=bool), k=1) + scores = np.where(mask, -np.inf, scores) + + # Numerically stable softmax along the key (last) axis. + scores -= scores.max(axis=-1, keepdims=True) + probs = np.exp(scores) + probs /= probs.sum(axis=-1, keepdims=True) + + ctx = probs @ Vh # (H, T, d) + ctx = ctx.transpose(1, 0, 2).reshape(T, D) # (T, D) + + return ctx @ W_o + b_o.reshape(-1) + + +# --------------------------------------------------------------------------- +# Forward pass +# --------------------------------------------------------------------------- + +def forward( + token_ids: Sequence[int], + weights_dir: str | Path, + dump: bool = True, +) -> dict[str, np.ndarray]: + """Run a single forward pass over ``token_ids``. + + Returns a dict of intermediate tensors plus ``logits`` of shape (T, V). 
+ Keys: + - ``embed`` : (T, D) wte+wpe + - ``h{i}_ln1`` : (T, D) post-LN1, pre-attn + - ``h{i}_attn`` : (T, D) post attention sub-block + (residual already added) + - ``h{i}_ln2`` : (T, D) post-LN2, pre-MLP + - ``h{i}_out`` : (T, D) post MLP sub-block (block out) + - ``lnf`` : (T, D) post final LayerNorm + - ``logits`` : (T, V) + """ + weights_dir = Path(weights_dir) + manifest = _load_manifest(weights_dir) + cfg = manifest["config"] + D, H, n_layer = cfg["n_embd"], cfg["n_head"], cfg["n_layer"] + n_ctx, vocab = cfg["n_ctx"], cfg["vocab_size"] + eps = float(cfg["layer_norm_eps"]) + + ids = np.asarray(list(token_ids), dtype=np.int64) + T = ids.shape[0] + if T > n_ctx: + raise ValueError(f"sequence length {T} exceeds n_ctx {n_ctx}") + if (ids < 0).any() or (ids >= vocab).any(): + raise ValueError(f"token ids out of range [0, {vocab})") + + W = _load_all(weights_dir, manifest) + wte = W["wte"] # (V, D) + wpe = W["wpe"] # (n_ctx, D) + + states: dict[str, np.ndarray] = {} + + # Embedding lookup: positions 0..T-1. + h = wte[ids] + wpe[np.arange(T)] + if dump: + states["embed"] = h.copy() + + # Pre-LN transformer blocks. + for i in range(n_layer): + # --- Attention sub-block --- + ln1 = _layer_norm(h, W[f"h{i}_ln1_gamma"], W[f"h{i}_ln1_beta"], eps) + if dump: + states[f"h{i}_ln1"] = ln1.copy() + + attn_out = _causal_self_attention( + ln1, + W[f"h{i}_W_Q"], W[f"h{i}_b_Q"], + W[f"h{i}_W_K"], W[f"h{i}_b_K"], + W[f"h{i}_W_V"], W[f"h{i}_b_V"], + W[f"h{i}_W_context"], W[f"h{i}_b_context"], + n_head=H, + ) + h = h + attn_out + if dump: + states[f"h{i}_attn"] = h.copy() + + # --- MLP sub-block --- + ln2 = _layer_norm(h, W[f"h{i}_ln2_gamma"], W[f"h{i}_ln2_beta"], eps) + if dump: + states[f"h{i}_ln2"] = ln2.copy() + + mlp_hidden = ln2 @ W[f"h{i}_W_intermediate"] + W[f"h{i}_b_intermediate"].reshape(-1) + mlp_hidden = _gelu_new(mlp_hidden) + mlp_out = mlp_hidden @ W[f"h{i}_W_out"] + W[f"h{i}_b_out"].reshape(-1) + h = h + mlp_out + if dump: + states[f"h{i}_out"] = h.copy() + + # Final LayerNorm + tied LM head. + h = _layer_norm(h, W["lnf_gamma"], W["lnf_beta"], eps) + if dump: + states["lnf"] = h.copy() + + logits = h @ wte.T # (T, V) + states["logits"] = logits + return states + + +# --------------------------------------------------------------------------- +# Optional: cross-check against HuggingFace +# --------------------------------------------------------------------------- + +def compare_with_hf( + states: dict[str, np.ndarray], + token_ids: Sequence[int], + model_id: str, + atol: float = 1e-4, + use_float64: bool = True, +) -> dict[str, float]: + """Run HF on the same tokens, return per-step max-abs-diff vs ``states``. + + HF's ``output_hidden_states`` returns ``n_layer + 1`` tensors: + idx 0 : post-embedding (matches our ``embed``) + idx 1..n_layer-1: input to block i (= output of block i-1, matches + our ``h{i-1}_out``) + idx n_layer : post-final-LayerNorm (matches our ``lnf``) + The pre-``ln_f`` output of the last block is *not* exposed by HF, so we + skip ``h{n_layer-1}_out`` here and rely on ``lnf`` and ``logits`` instead. + + By default we upcast HF to float64 -- otherwise per-block diffs of ~3e-3 + are pure float32 quantization noise, not a real correctness signal. 
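+
+    Minimal usage sketch (hypothetical paths; ``ids`` as produced by
+    ``_parse_tokens``)::
+
+        states = forward(ids, "weights/gpt2")
+        diffs = compare_with_hf(states, ids, model_id="gpt2", atol=1e-4)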
+ """ + try: + import torch + from transformers import GPT2LMHeadModel + except ImportError as e: + raise SystemExit( + "ERROR: torch + transformers required for --compare-hf" + ) from e + + ids = torch.tensor([list(token_ids)], dtype=torch.long) + model = GPT2LMHeadModel.from_pretrained(model_id) + if use_float64: + model = model.double() + model.eval() + with torch.no_grad(): + out = model(ids, output_hidden_states=True) + + hf_hidden = [h[0].numpy().astype(np.float64) for h in out.hidden_states] + hf_logits = out.logits[0].numpy().astype(np.float64) + n_layer = len(hf_hidden) - 1 + + diffs: dict[str, float] = {} + diffs["embed"] = float(np.abs(states["embed"] - hf_hidden[0]).max()) + for i in range(n_layer - 1): + key = f"h{i}_out" + if key in states: + diffs[key] = float(np.abs(states[key] - hf_hidden[i + 1]).max()) + if "lnf" in states: + diffs["lnf"] = float(np.abs(states["lnf"] - hf_hidden[n_layer]).max()) + diffs["logits"] = float(np.abs(states["logits"] - hf_logits).max()) + + precision = "float64" if use_float64 else "float32" + print(f"[oracle] HF cross-check (atol={atol}, hf={precision}):", file=sys.stderr) + worst = 0.0 + for k, v in diffs.items(): + flag = "OK " if v <= atol else "FAIL" + worst = max(worst, v) + print(f" {flag} {k:>10s} max|d| = {v:.3e}", file=sys.stderr) + print(f" worst = {worst:.3e}", file=sys.stderr) + return diffs + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +def _parse_tokens(spec: str) -> list[int]: + """Parse '--tokens' as comma-separated ints or @path/to/file (one per line).""" + if spec.startswith("@"): + return [int(x) for x in Path(spec[1:]).read_text().split() if x] + return [int(x) for x in spec.split(",") if x.strip()] + + +def _build_argparser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser( + prog="np_oracle_gpt2", + description="Pure-NumPy GPT-2 forward pass over converted weights.", + ) + p.add_argument("--weights", required=True, + help="directory containing manifest.json + *.csv from convert_gpt2.py") + p.add_argument("--tokens", default="464,2068,7586,21831,625,262,16931,3290,13", + help="comma-separated token ids, or @file with whitespace-separated ids " + "(default: small fixed prompt for smoke testing)") + p.add_argument("--dump", default=None, + help="if set, write all intermediate states to /states.npz") + p.add_argument("--compare-hf", action="store_true", + help="cross-check logits + hidden states against HuggingFace") + p.add_argument("--atol", type=float, default=1e-4, + help="tolerance for --compare-hf (default: 1e-4)") + p.add_argument("--hf-float32", action="store_true", + help="run HF in float32 (default: cast HF to float64 to match oracle)") + return p + + +def main(argv: Sequence[str] | None = None) -> int: + args = _build_argparser().parse_args(list(argv) if argv is not None else None) + weights = Path(args.weights) + ids = _parse_tokens(args.tokens) + print(f"[oracle] forward over T={len(ids)} tokens from {weights}", file=sys.stderr) + + states = forward(ids, weights, dump=True) + + print(f"[oracle] logits shape = {states['logits'].shape} " + f"argmax(last) = {int(states['logits'][-1].argmax())}", file=sys.stderr) + + if args.dump: + out = Path(args.dump) + out.mkdir(parents=True, exist_ok=True) + np.savez(out / "states.npz", **states, token_ids=np.asarray(ids, dtype=np.int64)) + print(f"[oracle] wrote {len(states)} arrays to {out / 'states.npz'}", file=sys.stderr) + + if 
args.compare_hf: + manifest = _load_manifest(weights) + compare_with_hf(states, ids, manifest["model"], + atol=args.atol, use_float64=not args.hf_float32) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/staging/llm-native/tools/pack_weights.py b/scripts/staging/llm-native/tools/pack_weights.py new file mode 100644 index 00000000000..cda734924b6 --- /dev/null +++ b/scripts/staging/llm-native/tools/pack_weights.py @@ -0,0 +1,142 @@ +# ------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# ------------------------------------------------------------- + +"""Pack per-layer GPT-2 weight CSVs into stacked CSVs for the DML driver. + +SystemDS' DML parser requires ``read()`` filenames to be const-string- +traceable, which precludes loop-variable filename construction. The DML +driver therefore reads one stacked tensor per parameter type and row-slices +it inside the per-layer loop. + +For each parameter type ``P`` that appears once per layer, this script +emits ``all_
P.csv`` (+ ``.csv.mtd``) of shape:
+
+  - ``(n_layer * rows_per_layer, cols_per_layer)`` for 2-D matrices
+  - ``(n_layer, dim)`` for row-vector biases (already 1xD per layer)
+
+The original per-layer files are untouched -- they remain the canonical
+artifact for the NumPy oracle, the converter tests, and human inspection.
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import sys
+from pathlib import Path
+from typing import Iterable
+
+import numpy as np
+
+
+# Per-layer parameter names emitted by ``convert_gpt2.py``. Keep in sync
+# with the ``h{i}_*`` keys it writes; the DML driver expects exactly these
+# stacked files.
+_PER_LAYER_KEYS = (
+    "ln1_gamma", "ln1_beta",
+    "W_Q", "b_Q", "W_K", "b_K", "W_V", "b_V",
+    "W_context", "b_context",
+    "ln2_gamma", "ln2_beta",
+    "W_intermediate", "b_intermediate",
+    "W_out", "b_out",
+)
+
+
+_MTD_TEMPLATE = {
+    "data_type": "matrix",
+    "value_type": "double",
+    "format": "csv",
+    "header": False,
+    "sep": ",",
+}
+
+
+def _load_csv(path: Path) -> np.ndarray:
+    return np.loadtxt(path, delimiter=",", dtype=np.float64, ndmin=2)
+
+
+def _write_csv(path: Path, arr: np.ndarray) -> None:
+    np.savetxt(path, arr, fmt="%.17g", delimiter=",")
+    mtd = dict(_MTD_TEMPLATE)
+    mtd["rows"] = int(arr.shape[0])
+    mtd["cols"] = int(arr.shape[1])
+    with open(path.with_suffix(path.suffix + ".mtd"), "w") as f:
+        json.dump(mtd, f, indent=2)
+
+
+def pack(weights_dir: str | Path) -> dict[str, list[int]]:
+    """Pack per-layer CSVs in ``weights_dir`` into ``all_*.csv`` files.
+
+    Returns a dict ``{stacked_name: [rows, cols]}`` of what was written.
+    """
+    weights_dir = Path(weights_dir)
+    with open(weights_dir / "manifest.json") as f:
+        manifest = json.load(f)
+    n_layer = int(manifest["config"]["n_layer"])
+
+    written: dict[str, list[int]] = {}
+    for key in _PER_LAYER_KEYS:
+        per_layer = [_load_csv(weights_dir / f"h{i}_{key}.csv") for i in range(n_layer)]
+
+        # Sanity: all layers must agree on shape so vstack is unambiguous.
+        s0 = per_layer[0].shape
+        for i, t in enumerate(per_layer):
+            if t.shape != s0:
+                raise RuntimeError(
+                    f"shape mismatch in {key}: layer 0 = {s0}, layer {i} = {t.shape}"
+                )
+
+        stacked = np.vstack(per_layer)
+        out = weights_dir / f"all_{key}.csv"
+        _write_csv(out, stacked)
+        written[f"all_{key}"] = [int(stacked.shape[0]), int(stacked.shape[1])]
+        print(f"[pack_weights] {out.name:<24s} ({stacked.shape[0]:>5d}, "
+              f"{stacked.shape[1]:>5d}) from {n_layer} layers of {s0}",
+              file=sys.stderr)
+
+    # Store stacked layout in the manifest under a separate key so existing
+    # consumers (oracle, tests) are unaffected.
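+    # E.g. manifest["stacked"]["shapes"]["all_W_Q"] -> [n_layer * D, D],
+    # while a bias like "all_b_Q" stacks to [n_layer, D].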
+ manifest.setdefault("stacked", {}) + manifest["stacked"]["per_layer_keys"] = list(_PER_LAYER_KEYS) + manifest["stacked"]["shapes"] = written + with open(weights_dir / "manifest.json", "w") as f: + json.dump(manifest, f, indent=2, sort_keys=True) + + print(f"[pack_weights] wrote {len(written)} stacked tensors to {weights_dir}", + file=sys.stderr) + return written + + +def main(argv: Iterable[str] | None = None) -> int: + p = argparse.ArgumentParser( + prog="pack_weights", + description="Stack per-layer GPT-2 CSVs into all_*.csv for the DML driver.", + ) + p.add_argument("--weights", required=True, + help="directory containing manifest.json + h{i}_*.csv") + args = p.parse_args(list(argv) if argv is not None else None) + pack(args.weights) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index e21c539d6d8..d24c0798c34 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -227,6 +227,7 @@ public enum Builtins { LMDS("lmDS", true), LMPREDICT("lmPredict", true), LMPREDICT_STATS("lmPredictStats", true), + LLMPREDICT("llmPredict", false, true), LOCAL("local", false), LOG("log", false), LOGSUMEXP("logSumExp", true), diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java index 1b0536416d6..94055d055c5 100644 --- a/src/main/java/org/apache/sysds/common/Opcodes.java +++ b/src/main/java/org/apache/sysds/common/Opcodes.java @@ -204,6 +204,7 @@ public enum Opcodes { GROUPEDAGG("groupedagg", InstructionType.ParameterizedBuiltin), RMEMPTY("rmempty", InstructionType.ParameterizedBuiltin), REPLACE("replace", InstructionType.ParameterizedBuiltin), + LLMPREDICT("llmpredict", InstructionType.ParameterizedBuiltin), LOWERTRI("lowertri", InstructionType.ParameterizedBuiltin), UPPERTRI("uppertri", InstructionType.ParameterizedBuiltin), REXPAND("rexpand", InstructionType.ParameterizedBuiltin), diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index 2e3543882d2..3414614991c 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -805,7 +805,7 @@ public static ReOrgOp valueOfByOpcode(String opcode) { /** Parameterized operations that require named variable arguments */ public enum ParamBuiltinOp { - AUTODIFF, CDF, CONTAINS, INVALID, INVCDF, GROUPEDAGG, RMEMPTY, REPLACE, REXPAND, + AUTODIFF, CDF, CONTAINS, INVALID, INVCDF, GROUPEDAGG, LLMPREDICT, RMEMPTY, REPLACE, REXPAND, LOWER_TRI, UPPER_TRI, TRANSFORMAPPLY, TRANSFORMDECODE, TRANSFORMCOLMAP, TRANSFORMMETA, TOKENIZE, TOSTRING, LIST, PARAMSERV diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java index 61a4b8b8f91..b791478214b 100644 --- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java +++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java @@ -187,6 +187,7 @@ public Lop constructLops() case LOWER_TRI: case UPPER_TRI: case TOKENIZE: + case LLMPREDICT: case TRANSFORMAPPLY: case TRANSFORMDECODE: case TRANSFORMCOLMAP: @@ -758,7 +759,7 @@ && getTargetHop().areDimsBelowThreshold() ) { if (_op == ParamBuiltinOp.TRANSFORMCOLMAP || _op == ParamBuiltinOp.TRANSFORMMETA || _op == ParamBuiltinOp.TOSTRING || _op == ParamBuiltinOp.LIST || _op == ParamBuiltinOp.CDF || 
_op == ParamBuiltinOp.INVCDF - || _op == ParamBuiltinOp.PARAMSERV) { + || _op == ParamBuiltinOp.PARAMSERV || _op == ParamBuiltinOp.LLMPREDICT) { _etype = ExecType.CP; } @@ -768,7 +769,7 @@ && getTargetHop().areDimsBelowThreshold() ) { switch(_op) { case CONTAINS: if(getTargetHop().optFindExecType() == ExecType.SPARK) - _etype = ExecType.SPARK; + _etype = ExecType.SPARK; break; default: // Do not change execution type. diff --git a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java index 3604121aac8..dcec28f76ca 100644 --- a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java +++ b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java @@ -176,6 +176,7 @@ public String getInstructions(String output) case CONTAINS: case REPLACE: case TOKENIZE: + case LLMPREDICT: case TRANSFORMAPPLY: case TRANSFORMDECODE: case TRANSFORMCOLMAP: diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index c6e7188d7bc..b1536371711 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -2007,6 +2007,7 @@ private Hop processParameterizedBuiltinFunctionExpression(ParameterizedBuiltinFu case LOWER_TRI: case UPPER_TRI: case TOKENIZE: + case LLMPREDICT: case TRANSFORMAPPLY: case TRANSFORMDECODE: case TRANSFORMCOLMAP: diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java index 314440628e0..cd9699a1082 100644 --- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java @@ -61,6 +61,7 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier pbHopMap.put(Builtins.GROUPEDAGG, ParamBuiltinOp.GROUPEDAGG); pbHopMap.put(Builtins.RMEMPTY, ParamBuiltinOp.RMEMPTY); pbHopMap.put(Builtins.REPLACE, ParamBuiltinOp.REPLACE); + pbHopMap.put(Builtins.LLMPREDICT, ParamBuiltinOp.LLMPREDICT); pbHopMap.put(Builtins.LOWER_TRI, ParamBuiltinOp.LOWER_TRI); pbHopMap.put(Builtins.UPPER_TRI, ParamBuiltinOp.UPPER_TRI); @@ -211,6 +212,10 @@ public void validateExpression(HashMap ids, HashMap valid = new HashSet<>(Arrays.asList( + "target", "url", "model", "max_tokens", "temperature", "top_p", "concurrency")); + checkInvalidParameters(getOpCode(), getVarParams(), valid); + checkDataType(false, "llmPredict", TF_FN_PARAM_DATA, DataType.FRAME, conditional); + checkStringParam(false, "llmPredict", "url", conditional); + + // validate numeric parameter types at compile time (when literal). + // Note: no range validation -- different LLM servers accept different + // ranges (e.g. vLLM allows temperature=0.0, OpenAI requires >0). + // Runtime errors from the server are more informative than + // compile-time checks locked to one server's rules. 
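+		// Usage sketch from DML (hypothetical values; cf. JMLCLLMInferenceTest):
+		//   R = llmPredict(target=P, url="http://localhost:8080/v1/completions",
+		//       max_tokens=64, temperature=0.0, top_p=0.9, concurrency=4);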
+ checkNumericScalarParam("llmPredict", "max_tokens", conditional); + checkNumericScalarParam("llmPredict", "temperature", conditional); + checkNumericScalarParam("llmPredict", "top_p", conditional); + checkNumericScalarParam("llmPredict", "concurrency", conditional); + + output.setDataType(DataType.FRAME); + output.setValueType(ValueType.STRING); + output.setDimensions(-1, -1); + } + + private void checkNumericScalarParam(String fname, String pname, boolean conditional) { + Expression expr = getVarParam(pname); + if(expr == null) return; + if(expr instanceof DataIdentifier) { + DataIdentifier di = (DataIdentifier) expr; + if(di.getDataType() != null && !di.getDataType().isScalar()) { + raiseValidateError( + String.format("Function %s: parameter '%s' must be a scalar, got %s.", + fname, pname, di.getDataType()), conditional); + } + } + } + // example: A = transformapply(target=X, meta=M, spec=s) private void validateTransformApply(DataIdentifier output, boolean conditional) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/LlmPredictCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/LlmPredictCPInstruction.java new file mode 100644 index 00000000000..da2c123e89a --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/LlmPredictCPInstruction.java @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.sysds.runtime.instructions.cp;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.net.ConnectException;
+import java.net.HttpURLConnection;
+import java.net.MalformedURLException;
+import java.net.SocketTimeoutException;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
+import org.apache.wink.json4j.JSONObject;
+
+public class LlmPredictCPInstruction extends ParameterizedBuiltinCPInstruction {
+
+	protected LlmPredictCPInstruction(LinkedHashMap<String, String> paramsMap,
+		CPOperand out, String opcode, String istr) {
+		super(null, paramsMap, out, opcode, istr);
+	}
+
+	@Override
+	public void processInstruction(ExecutionContext ec) {
+		FrameBlock prompts = ec.getFrameInput(params.get("target"));
+		String url = params.get("url");
+		String model = params.containsKey("model") ?
+			params.get("model") : null;
+		int maxTokens = params.containsKey("max_tokens") ?
+			Integer.parseInt(params.get("max_tokens")) : 512;
+		double temperature = params.containsKey("temperature") ?
+			Double.parseDouble(params.get("temperature")) : 0.0;
+		double topP = params.containsKey("top_p") ?
+			Double.parseDouble(params.get("top_p")) : 0.9;
+		int concurrency = params.containsKey("concurrency") ?
+			Integer.parseInt(params.get("concurrency")) : 1;
+		concurrency = Math.max(1, Math.min(concurrency, 128));
+
+		int n = prompts.getNumRows();
+		String[][] data = new String[n][];
+
+		List<Callable<String[]>> tasks = new ArrayList<>(n);
+		for(int i = 0; i < n; i++) {
+			String prompt = prompts.get(i, 0).toString();
+			tasks.add(() -> callLlmEndpoint(prompt, url, model, maxTokens, temperature, topP));
+		}
+
+		try {
+			if(concurrency <= 1) {
+				for(int i = 0; i < n; i++)
+					data[i] = tasks.get(i).call();
+			}
+			else {
+				ExecutorService pool = Executors.newFixedThreadPool(
+					Math.min(concurrency, n));
+				List<Future<String[]>> futures = pool.invokeAll(tasks);
+				pool.shutdown();
+				for(int i = 0; i < n; i++)
+					data[i] = futures.get(i).get();
+			}
+		}
+		catch(DMLRuntimeException e) {
+			throw e;
+		}
+		catch(Exception e) {
+			throw new DMLRuntimeException("llmPredict failed: " + e.getMessage(), e);
+		}
+
+		ValueType[] schema = {ValueType.STRING, ValueType.STRING,
+			ValueType.INT64, ValueType.INT64, ValueType.INT64};
+		String[] colNames = {"prompt", "generated_text", "time_ms", "input_tokens", "output_tokens"};
+		FrameBlock fbout = new FrameBlock(schema, colNames);
+		for(String[] row : data)
+			fbout.appendRow(row);
+
+		ec.setFrameOutput(output.getName(), fbout);
+		ec.releaseFrameInput(params.get("target"));
+	}
+
+	// No retry logic by design: as a database built-in, llmPredict should
+	// fail fast on transient errors and let the caller (DML script) decide
+	// whether and how to retry. Silent retries with backoff would make
+	// execution time unpredictable.
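+	// Returns one output row {prompt, generated_text, time_ms, input_tokens,
+	// output_tokens}: time_ms is the wall-clock latency of this single request;
+	// the token counts come from the server's optional "usage" block (0 if absent).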
+ private static String[] callLlmEndpoint(String prompt, String url, + String model, int maxTokens, double temperature, double topP) { + long t0 = System.nanoTime(); + + // validate URL and open connection + HttpURLConnection conn; + try { + conn = (HttpURLConnection) new URI(url).toURL().openConnection(); + } + catch(URISyntaxException | MalformedURLException | IllegalArgumentException e) { + throw new DMLRuntimeException( + "llmPredict: invalid URL '" + url + "'. " + + "Expected format: http://host:port/v1/completions", e); + } + catch(IOException e) { + throw new DMLRuntimeException( + "llmPredict: cannot open connection to '" + url + "'.", e); + } + + try { + JSONObject req = new JSONObject(); + if(model != null) + req.put("model", model); + req.put("prompt", prompt); + req.put("max_tokens", maxTokens); + req.put("temperature", temperature); + req.put("top_p", topP); + + conn.setRequestMethod("POST"); + conn.setRequestProperty("Content-Type", "application/json"); + conn.setConnectTimeout(10_000); + conn.setReadTimeout(300_000); + conn.setDoOutput(true); + + try(OutputStream os = conn.getOutputStream()) { + os.write(req.toString().getBytes(StandardCharsets.UTF_8)); + } + + int httpCode = conn.getResponseCode(); + if(httpCode != 200) { + String errBody = ""; + try(InputStream es = conn.getErrorStream()) { + if(es != null) + errBody = new String(es.readAllBytes(), StandardCharsets.UTF_8); + } + catch(Exception ignored) {} + throw new DMLRuntimeException( + "llmPredict: endpoint returned HTTP " + httpCode + + " for '" + url + "'." + + (errBody.isEmpty() ? "" : " Response: " + errBody)); + } + + String body; + try(InputStream is = conn.getInputStream()) { + body = new String(is.readAllBytes(), StandardCharsets.UTF_8); + } + + JSONObject resp = new JSONObject(body); + if(!resp.has("choices") || resp.getJSONArray("choices").length() == 0) { + String errMsg = resp.has("error") ? resp.optString("error") : body; + throw new DMLRuntimeException( + "llmPredict: server response missing 'choices'. Response: " + errMsg); + } + String text = resp.getJSONArray("choices") + .getJSONObject(0).getString("text"); + long elapsed = (System.nanoTime() - t0) / 1_000_000; + int inTok = 0, outTok = 0; + if(resp.has("usage")) { + JSONObject usage = resp.getJSONObject("usage"); + inTok = usage.has("prompt_tokens") ? usage.getInt("prompt_tokens") : 0; + outTok = usage.has("completion_tokens") ? usage.getInt("completion_tokens") : 0; + } + return new String[]{prompt, text, + String.valueOf(elapsed), String.valueOf(inTok), String.valueOf(outTok)}; + } + catch(ConnectException e) { + throw new DMLRuntimeException( + "llmPredict: connection refused to '" + url + "'. " + + "Ensure the LLM server is running and reachable.", e); + } + catch(SocketTimeoutException e) { + throw new DMLRuntimeException( + "llmPredict: timed out connecting to '" + url + "'. 
" + + "Ensure the LLM server is running and reachable.", e); + } + catch(IOException e) { + throw new DMLRuntimeException( + "llmPredict: I/O error communicating with '" + url + "'.", e); + } + catch(DMLRuntimeException e) { + throw e; + } + catch(Exception e) { + throw new DMLRuntimeException( + "llmPredict: failed to get response from '" + url + "'.", e); + } + finally { + conn.disconnect(); + } + } + + @Override + public Pair getLineageItem(ExecutionContext ec) { + CPOperand target = new CPOperand(params.get("target"), ValueType.STRING, DataType.FRAME); + CPOperand urlOp = new CPOperand(params.get("url"), ValueType.STRING, DataType.SCALAR, true); + return Pair.of(output.getName(), + new LineageItem(getOpcode(), LineageItemUtils.getLineage(ec, target, urlOp))); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java index 119589a3033..ac2f527f06c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java @@ -158,6 +158,9 @@ else if(opcode.equals(Opcodes.TRANSFORMAPPLY.toString()) || opcode.equals(Opcode || opcode.equals(Opcodes.TOSTRING.toString()) || opcode.equals(Opcodes.NVLIST.toString()) || opcode.equals(Opcodes.AUTODIFF.toString())) { return new ParameterizedBuiltinCPInstruction(null, paramsMap, out, opcode, str); } + else if(opcode.equals(Opcodes.LLMPREDICT.toString())) { + return new LlmPredictCPInstruction(paramsMap, out, opcode, str); + } else if(Opcodes.PARAMSERV.toString().equals(opcode)) { return new ParamservBuiltinCPInstruction(null, paramsMap, out, opcode, str); } @@ -324,6 +327,7 @@ else if(opcode.equalsIgnoreCase(Opcodes.TOKENIZE.toString())) { ec.setFrameOutput(output.getName(), fbout); ec.releaseFrameInput(params.get("target")); } + else if(opcode.equalsIgnoreCase(Opcodes.TRANSFORMAPPLY.toString())) { // acquire locks FrameBlock data = ec.getFrameInput(params.get("target")); diff --git a/src/main/python/llm_server.py b/src/main/python/llm_server.py new file mode 100644 index 00000000000..b538d871ba8 --- /dev/null +++ b/src/main/python/llm_server.py @@ -0,0 +1,117 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +"""Local inference server for llmPredict. Loads a HuggingFace model +and serves it at http://localhost:PORT/v1/completions. 
+ +Usage: python llm_server.py distilgpt2 --port 8080 +""" + +import argparse +import json +import sys +import time +from http.server import HTTPServer, BaseHTTPRequestHandler + +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM + + +class InferenceHandler(BaseHTTPRequestHandler): + + def do_POST(self): + if self.path != "/v1/completions": + self.send_error(404) + return + length = int(self.headers.get("Content-Length", 0)) + body = json.loads(self.rfile.read(length)) + + prompt = body.get("prompt", "") + max_tokens = int(body.get("max_tokens", 512)) + temperature = float(body.get("temperature", 0.0)) + top_p = float(body.get("top_p", 0.9)) + + model = self.server.model + tokenizer = self.server.tokenizer + + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + input_len = inputs["input_ids"].shape[1] + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=max_tokens, + temperature=temperature if temperature > 0 else 1.0, + top_p=top_p, + do_sample=temperature > 0, + ) + new_tokens = outputs[0][input_len:] + text = tokenizer.decode(new_tokens, skip_special_tokens=True) + + resp = { + "choices": [{"text": text}], + "usage": { + "prompt_tokens": input_len, + "completion_tokens": len(new_tokens), + }, + } + payload = json.dumps(resp).encode("utf-8") + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(payload))) + self.end_headers() + self.wfile.write(payload) + + def log_message(self, fmt, *args): + sys.stderr.write("[llm_server] %s\n" % (fmt % args)) + + +def main(): + parser = argparse.ArgumentParser(description="OpenAI-compatible LLM server") + parser.add_argument("model", help="HuggingFace model name") + parser.add_argument("--port", type=int, default=8080) + args = parser.parse_args() + + print(f"Loading model: {args.model}", flush=True) + tokenizer = AutoTokenizer.from_pretrained(args.model) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + if torch.cuda.is_available(): + print(f"CUDA available: {torch.cuda.device_count()} GPU(s)", flush=True) + model = AutoModelForCausalLM.from_pretrained( + args.model, device_map="auto", torch_dtype=torch.float16) + else: + model = AutoModelForCausalLM.from_pretrained(args.model) + model.eval() + print(f"Model loaded on {next(model.parameters()).device}", flush=True) + + server = HTTPServer(("0.0.0.0", args.port), InferenceHandler) + server.model = model + server.tokenizer = tokenizer + print(f"Serving on http://0.0.0.0:{args.port}/v1/completions", flush=True) + try: + server.serve_forever() + except KeyboardInterrupt: + print("Shutting down", flush=True) + server.server_close() + + +if __name__ == "__main__": + main() diff --git a/src/test/java/org/apache/sysds/test/applications/nn/transformers/MultiAttentionLayerTest.java b/src/test/java/org/apache/sysds/test/applications/nn/transformers/MultiAttentionLayerTest.java index 225d8983aa9..37c4197ca4d 100644 --- a/src/test/java/org/apache/sysds/test/applications/nn/transformers/MultiAttentionLayerTest.java +++ b/src/test/java/org/apache/sysds/test/applications/nn/transformers/MultiAttentionLayerTest.java @@ -27,6 +27,7 @@ public class MultiAttentionLayerTest extends AutomatedTestBase { private static final String TEST_NAME_FORWARD = "multi_attention_forward"; private static final String TEST_NAME_BACKWARD = "multi_attention_backward"; + private static final String TEST_NAME_FORWARD_CAUSAL = "multi_attention_forward_causal"; private static 
final String TEST_DIR = "applications/nn/component/"; private static final String RESOURCE_DIR = "src/test/resources/component/transformers/multi_attention_layer/"; @@ -35,6 +36,7 @@ public void setUp() { TestUtils.clearAssertionInformation(); addTestConfiguration(TEST_NAME_FORWARD, new TestConfiguration(TEST_DIR, TEST_NAME_FORWARD)); addTestConfiguration(TEST_NAME_BACKWARD, new TestConfiguration(TEST_DIR, TEST_NAME_BACKWARD)); + addTestConfiguration(TEST_NAME_FORWARD_CAUSAL, new TestConfiguration(TEST_DIR, TEST_NAME_FORWARD_CAUSAL)); } @Test @@ -67,6 +69,41 @@ public void testMultiAttentionBackwardSmall() { runMultiAttentionTest("test6", 1, 1, 1, 1, 0, TEST_NAME_BACKWARD, 1e-5, false); } + @Test + public void testMultiAttentionForwardCausalMask() { + Types.ExecMode platformOld = setExecMode(Types.ExecMode.SINGLE_NODE); + try { + getAndLoadTestConfiguration(TEST_NAME_FORWARD_CAUSAL); + fullDMLScriptName = getScript(); + programArgs = new String[] { + "-stats", "-args", + output("causal_token1_diff"), + output("causal_token3_diff"), + output("noncausal_token1_diff"), + }; + + runTest(true, EXCEPTION_NOT_EXPECTED, null, -1); + + double causalToken1Diff = + (Double) readDMLScalarFromOutputDir("causal_token1_diff").values().toArray()[0]; + double causalToken3Diff = + (Double) readDMLScalarFromOutputDir("causal_token3_diff").values().toArray()[0]; + double noncausalToken1Diff = + (Double) readDMLScalarFromOutputDir("noncausal_token1_diff").values().toArray()[0]; + + assert causalToken1Diff < 1e-6; + assert causalToken3Diff > 1e-6; + assert noncausalToken1Diff > 1e-6; + } + catch (Throwable ex) { + ex.printStackTrace(System.out); + throw new RuntimeException(ex); + } + finally { + resetExecMode(platformOld); + } + } + private void runMultiAttentionTest(String testSuffix, int batchSize, int seqLength, int numHeads, int embeddingDim, int debug, String testname, double precision, boolean isForward) { // Set execution platform diff --git a/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java new file mode 100644 index 00000000000..bc7817a7d17 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/jmlc/JMLCLLMInferenceTest.java @@ -0,0 +1,572 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.functions.jmlc; + +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.HashMap; +import java.util.Map; + +import com.sun.net.httpserver.HttpServer; + +import org.apache.sysds.api.jmlc.Connection; +import org.apache.sysds.api.jmlc.PreparedScript; +import org.apache.sysds.api.jmlc.ResultVariables; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.test.AutomatedTestBase; +import org.junit.Assert; +import org.junit.Test; + +/** + * Tests for llmPredict built-in via JMLC. + * Needs an OpenAI-compatible server on localhost:8080. + */ +public class JMLCLLMInferenceTest extends AutomatedTestBase { + private final static String TEST_NAME = "JMLCLLMInferenceTest"; + private final static String TEST_DIR = "functions/jmlc/"; + private final static String LLM_URL = "http://localhost:8080/v1/completions"; + + private final static String DML_SCRIPT = + "prompts = read(\"prompts\", data_type=\"frame\")\n" + + + "results = llmPredict(target=prompts, url=$url, max_tokens=$mt, temperature=$temp, top_p=$tp)\n" + + "write(results, \"results\")"; + + @Override + public void setUp() { + addTestConfiguration(TEST_DIR, TEST_NAME); + getAndLoadTestConfiguration(TEST_NAME); + } + + @Test + public void testSinglePrompt() { + Connection conn = null; + try { + conn = new Connection(); + Map args = new HashMap<>(); + args.put("$url", LLM_URL); + args.put("$mt", "20"); + args.put("$temp", "0.7"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new String[]{"results"}); + + String[][] promptData = new String[][]{{"The meaning of life is"}}; + ps.setFrame("prompts", promptData); + + ResultVariables rv = ps.executeScript(); + FrameBlock result = rv.getFrameBlock("results"); + + Assert.assertNotNull("Result should not be null", result); + Assert.assertEquals("Should have 1 row", 1, result.getNumRows()); + Assert.assertEquals("Should have 5 columns", 5, result.getNumColumns()); + String generated = result.get(0, 1).toString(); + Assert.assertFalse("Generated text should not be empty", generated.isEmpty()); + + System.out.println("Prompt: " + promptData[0][0]); + System.out.println("Generated: " + generated); + } catch (Exception e) { + e.printStackTrace(); + org.junit.Assume.assumeNoException("LLM server not available", e); + } finally { + if (conn != null) conn.close(); + } + } + + @Test + public void testServerUnreachable() { + // should throw DMLRuntimeException, not hang + Connection conn = null; + try { + conn = new Connection(); + String deadUrl = "http://localhost:19999/v1/completions"; + Map args = new HashMap<>(); + args.put("$url", deadUrl); + args.put("$mt", "20"); + args.put("$temp", "0.0"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new String[]{"results"}); + + String[][] promptData = new String[][]{{"Hello"}}; + ps.setFrame("prompts", promptData); + + try { + ps.executeScript(); + Assert.fail("Expected DMLRuntimeException for unreachable server"); + } + catch (DMLRuntimeException e) { + String fullMsg = getExceptionChainMessage(e); + System.out.println("Correctly caught: " + fullMsg); + Assert.assertTrue("Error should mention connection issue", + fullMsg.contains("connection refused") + || fullMsg.contains("Connection refused") + || fullMsg.contains("server is running")); + } + } + catch 
(Exception e) { + e.printStackTrace(); + org.junit.Assume.assumeNoException( + "Could not set up negative test", e); + } + finally { + if (conn != null) conn.close(); + } + } + + @Test + public void testInvalidUrl() { + Connection conn = null; + try { + conn = new Connection(); + Map args = new HashMap<>(); + args.put("$url", "not-a-valid-url"); + args.put("$mt", "20"); + args.put("$temp", "0.0"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new String[]{"results"}); + + String[][] promptData = new String[][]{{"Hello"}}; + ps.setFrame("prompts", promptData); + + try { + ps.executeScript(); + Assert.fail("Expected DMLRuntimeException for invalid URL"); + } + catch (DMLRuntimeException e) { + String fullMsg = getExceptionChainMessage(e); + System.out.println("Correctly caught: " + fullMsg); + Assert.assertTrue("Error should mention invalid URL", + fullMsg.contains("invalid URL") + || fullMsg.contains("Invalid URL")); + } + } + catch (Exception e) { + e.printStackTrace(); + org.junit.Assume.assumeNoException( + "Could not set up negative test", e); + } + finally { + if (conn != null) conn.close(); + } + } + + private static String getExceptionChainMessage(Throwable t) { + StringBuilder sb = new StringBuilder(); + while(t != null) { + if(sb.length() > 0) sb.append(" | "); + if(t.getMessage() != null) sb.append(t.getMessage()); + t = t.getCause(); + } + return sb.toString(); + } + + @Test + public void testConcurrency() { + Connection conn = null; + try { + conn = new Connection(); + String dmlConc = + "prompts = read(\"prompts\", data_type=\"frame\")\n" + + "results = llmPredict(target=prompts, url=$url, max_tokens=$mt, " + + "temperature=$temp, top_p=$tp, concurrency=$conc)\n" + + "write(results, \"results\")"; + Map args = new HashMap<>(); + args.put("$url", LLM_URL); + args.put("$mt", "20"); + args.put("$temp", "0.0"); + args.put("$tp", "0.9"); + args.put("$conc", "2"); + PreparedScript ps = conn.prepareScript(dmlConc, args, + new String[]{"prompts"}, new String[]{"results"}); + + String[][] promptData = new String[][]{ + {"Hello world"}, {"Test prompt"}, {"Another test"} + }; + ps.setFrame("prompts", promptData); + + ResultVariables rv = ps.executeScript(); + FrameBlock result = rv.getFrameBlock("results"); + + Assert.assertNotNull("Result should not be null", result); + Assert.assertEquals("Should have 3 rows", 3, result.getNumRows()); + Assert.assertEquals("Should have 5 columns", 5, result.getNumColumns()); + } catch (Exception e) { + e.printStackTrace(); + org.junit.Assume.assumeNoException("LLM server not available", e); + } finally { + if (conn != null) conn.close(); + } + } + + @Test + public void testHttpErrorResponse() { + // mock server that returns HTTP 500 + HttpServer server = null; + Connection conn = null; + try { + server = HttpServer.create(new InetSocketAddress(0), 0); + int port = server.getAddress().getPort(); + server.createContext("/v1/completions", exchange -> { + byte[] resp = "{\"error\": \"internal server error\"}".getBytes(StandardCharsets.UTF_8); + exchange.sendResponseHeaders(500, resp.length); + try(OutputStream os = exchange.getResponseBody()) { + os.write(resp); + } + }); + server.start(); + + conn = new Connection(); + Map args = new HashMap<>(); + args.put("$url", "http://localhost:" + port + "/v1/completions"); + args.put("$mt", "20"); + args.put("$temp", "0.0"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new 
String[]{"results"}); + ps.setFrame("prompts", new String[][]{{"Hello"}}); + + try { + ps.executeScript(); + Assert.fail("Expected DMLRuntimeException for HTTP 500"); + } + catch (DMLRuntimeException e) { + String fullMsg = getExceptionChainMessage(e); + System.out.println("Correctly caught HTTP 500: " + fullMsg); + Assert.assertTrue("Error should mention HTTP 500", + fullMsg.contains("HTTP 500")); + } + } + catch (Exception e) { + e.printStackTrace(); + org.junit.Assume.assumeNoException( + "Could not set up mock server", e); + } + finally { + if (server != null) server.stop(0); + if (conn != null) conn.close(); + } + } + + @Test + public void testMalformedJsonResponse() { + // mock server that returns HTTP 200 with invalid JSON + HttpServer server = null; + Connection conn = null; + try { + server = HttpServer.create(new InetSocketAddress(0), 0); + int port = server.getAddress().getPort(); + server.createContext("/v1/completions", exchange -> { + byte[] resp = "this is not json at all".getBytes(StandardCharsets.UTF_8); + exchange.sendResponseHeaders(200, resp.length); + try(OutputStream os = exchange.getResponseBody()) { + os.write(resp); + } + }); + server.start(); + + conn = new Connection(); + Map args = new HashMap<>(); + args.put("$url", "http://localhost:" + port + "/v1/completions"); + args.put("$mt", "20"); + args.put("$temp", "0.0"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new String[]{"results"}); + ps.setFrame("prompts", new String[][]{{"Hello"}}); + + try { + ps.executeScript(); + Assert.fail("Expected DMLRuntimeException for malformed JSON"); + } + catch (DMLRuntimeException e) { + String fullMsg = getExceptionChainMessage(e); + System.out.println("Correctly caught malformed JSON: " + fullMsg); + Assert.assertTrue("Error should mention response issue", + fullMsg.contains("failed") || fullMsg.contains("response")); + } + } + catch (Exception e) { + e.printStackTrace(); + org.junit.Assume.assumeNoException( + "Could not set up mock server", e); + } + finally { + if (server != null) server.stop(0); + if (conn != null) conn.close(); + } + } + + @Test + public void testMissingChoicesInResponse() { + // mock server that returns valid JSON but no "choices" array + HttpServer server = null; + Connection conn = null; + try { + server = HttpServer.create(new InetSocketAddress(0), 0); + int port = server.getAddress().getPort(); + server.createContext("/v1/completions", exchange -> { + byte[] resp = "{\"id\": \"test\", \"object\": \"text_completion\"}" + .getBytes(StandardCharsets.UTF_8); + exchange.sendResponseHeaders(200, resp.length); + try(OutputStream os = exchange.getResponseBody()) { + os.write(resp); + } + }); + server.start(); + + conn = new Connection(); + Map args = new HashMap<>(); + args.put("$url", "http://localhost:" + port + "/v1/completions"); + args.put("$mt", "20"); + args.put("$temp", "0.0"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new String[]{"results"}); + ps.setFrame("prompts", new String[][]{{"Hello"}}); + + try { + ps.executeScript(); + Assert.fail("Expected DMLRuntimeException for missing choices"); + } + catch (DMLRuntimeException e) { + String fullMsg = getExceptionChainMessage(e); + System.out.println("Correctly caught missing choices: " + fullMsg); + Assert.assertTrue("Error should mention missing choices", + fullMsg.contains("choices")); + } + } + catch (Exception e) { + e.printStackTrace(); + 
org.junit.Assume.assumeNoException( + "Could not set up mock server", e); + } + finally { + if (server != null) server.stop(0); + if (conn != null) conn.close(); + } + } + + @Test + public void testBatchInference() { + Connection conn = null; + try { + conn = new Connection(); + Map args = new HashMap<>(); + args.put("$url", LLM_URL); + args.put("$mt", "20"); + args.put("$temp", "0.7"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new String[]{"results"}); + + String[] prompts = { + "The meaning of life is", + "Machine learning is", + "Apache SystemDS enables" + }; + String[][] promptData = new String[prompts.length][1]; + for (int i = 0; i < prompts.length; i++) + promptData[i][0] = prompts[i]; + ps.setFrame("prompts", promptData); + + ResultVariables rv = ps.executeScript(); + FrameBlock result = rv.getFrameBlock("results"); + + Assert.assertNotNull("Result should not be null", result); + Assert.assertEquals("Should have 3 rows", 3, result.getNumRows()); + Assert.assertEquals("Should have 5 columns", 5, result.getNumColumns()); + + for (int i = 0; i < prompts.length; i++) { + String prompt = result.get(i, 0).toString(); + String generated = result.get(i, 1).toString(); + long timeMs = Long.parseLong(result.get(i, 2).toString()); + Assert.assertEquals("Prompt should match", prompts[i], prompt); + Assert.assertFalse("Generated text should not be empty", generated.isEmpty()); + Assert.assertTrue("Time should be positive", timeMs > 0); + System.out.println("Prompt: " + prompt); + System.out.println("Generated: " + generated + " (" + timeMs + "ms)"); + } + } catch (Exception e) { + e.printStackTrace(); + org.junit.Assume.assumeNoException("LLM server not available", e); + } finally { + if (conn != null) conn.close(); + } + } + + @Test + public void testMockSinglePrompt() { + // mock server that returns a valid OpenAI-compatible response + // runs in CI without a real LLM server + HttpServer server = null; + Connection conn = null; + try { + server = HttpServer.create(new InetSocketAddress(0), 0); + int port = server.getAddress().getPort(); + server.createContext("/v1/completions", exchange -> { + String body = "{\"choices\":[{\"text\":\"42 is the answer\"}]," + + "\"usage\":{\"prompt_tokens\":5,\"completion_tokens\":4}}"; + byte[] resp = body.getBytes(StandardCharsets.UTF_8); + exchange.sendResponseHeaders(200, resp.length); + try(OutputStream os = exchange.getResponseBody()) { + os.write(resp); + } + }); + server.start(); + + conn = new Connection(); + Map args = new HashMap<>(); + args.put("$url", "http://localhost:" + port + "/v1/completions"); + args.put("$mt", "20"); + args.put("$temp", "0.0"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new String[]{"results"}); + ps.setFrame("prompts", new String[][]{{"What is 6 times 7?"}}); + + ResultVariables rv = ps.executeScript(); + FrameBlock result = rv.getFrameBlock("results"); + + Assert.assertNotNull("Result should not be null", result); + Assert.assertEquals("Should have 1 row", 1, result.getNumRows()); + Assert.assertEquals("Should have 5 columns", 5, result.getNumColumns()); + Assert.assertEquals("Generated text should match", "42 is the answer", + result.get(0, 1).toString()); + Assert.assertEquals("Input tokens should be 5", "5", + result.get(0, 3).toString()); + Assert.assertEquals("Output tokens should be 4", "4", + result.get(0, 4).toString()); + } + catch (Exception e) { + e.printStackTrace(); + 
org.junit.Assume.assumeNoException( + "Could not set up mock server", e); + } + finally { + if (server != null) server.stop(0); + if (conn != null) conn.close(); + } + } + + @Test + public void testMockBatchPrompts() { + // mock server returning different responses per prompt + HttpServer server = null; + Connection conn = null; + try { + server = HttpServer.create(new InetSocketAddress(0), 0); + int port = server.getAddress().getPort(); + server.createContext("/v1/completions", exchange -> { + // read request to get prompt + String reqBody = new String(exchange.getRequestBody().readAllBytes(), + StandardCharsets.UTF_8); + String response; + if (reqBody.contains("first")) + response = "response-1"; + else if (reqBody.contains("second")) + response = "response-2"; + else + response = "response-3"; + String body = "{\"choices\":[{\"text\":\"" + response + "\"}]," + + "\"usage\":{\"prompt_tokens\":3,\"completion_tokens\":1}}"; + byte[] resp = body.getBytes(StandardCharsets.UTF_8); + exchange.sendResponseHeaders(200, resp.length); + try(OutputStream os = exchange.getResponseBody()) { + os.write(resp); + } + }); + server.start(); + + conn = new Connection(); + Map args = new HashMap<>(); + args.put("$url", "http://localhost:" + port + "/v1/completions"); + args.put("$mt", "20"); + args.put("$temp", "0.0"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new String[]{"results"}); + ps.setFrame("prompts", new String[][]{ + {"first prompt"}, {"second prompt"}, {"third prompt"} + }); + + ResultVariables rv = ps.executeScript(); + FrameBlock result = rv.getFrameBlock("results"); + + Assert.assertEquals("Should have 3 rows", 3, result.getNumRows()); + Assert.assertEquals("Row 0 text", "response-1", result.get(0, 1).toString()); + Assert.assertEquals("Row 1 text", "response-2", result.get(1, 1).toString()); + Assert.assertEquals("Row 2 text", "response-3", result.get(2, 1).toString()); + } + catch (Exception e) { + e.printStackTrace(); + org.junit.Assume.assumeNoException( + "Could not set up mock server", e); + } + finally { + if (server != null) server.stop(0); + if (conn != null) conn.close(); + } + } + + @Test + public void testEmptyPromptFrame() { + // empty frame (0 rows) should produce empty result, not crash + HttpServer server = null; + Connection conn = null; + try { + server = HttpServer.create(new InetSocketAddress(0), 0); + int port = server.getAddress().getPort(); + server.createContext("/v1/completions", exchange -> { + // should never be called for 0 prompts + Assert.fail("Server should not be called for empty frame"); + }); + server.start(); + + conn = new Connection(); + Map args = new HashMap<>(); + args.put("$url", "http://localhost:" + port + "/v1/completions"); + args.put("$mt", "20"); + args.put("$temp", "0.0"); + args.put("$tp", "0.9"); + PreparedScript ps = conn.prepareScript(DML_SCRIPT, args, + new String[]{"prompts"}, new String[]{"results"}); + ps.setFrame("prompts", new String[0][1]); + + ResultVariables rv = ps.executeScript(); + FrameBlock result = rv.getFrameBlock("results"); + + Assert.assertNotNull("Result should not be null", result); + Assert.assertEquals("Should have 0 rows", 0, result.getNumRows()); + } + catch (Exception e) { + e.printStackTrace(); + org.junit.Assume.assumeNoException( + "Could not set up test", e); + } + finally { + if (server != null) server.stop(0); + if (conn != null) conn.close(); + } + } +} diff --git 
a/src/test/scripts/applications/nn/component/multi_attention_forward_causal.dml b/src/test/scripts/applications/nn/component/multi_attention_forward_causal.dml new file mode 100644 index 00000000000..80db7f4f7c6 --- /dev/null +++ b/src/test/scripts/applications/nn/component/multi_attention_forward_causal.dml @@ -0,0 +1,60 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +source("scripts/nn/layers/multi_attention.dml") as multi_attention + +# Small deterministic setup: +# B=1, H=1, T=3, D=2 => flattened shape is 1 x (T*H*D) = 1 x 6 +B = 1 +H = 1 +T = 3 +D = 2 +dropout_p = 0.0 + +# Use non-uniform Q/K to avoid degenerate/symmetric attention patterns. +Q = matrix("1 2 3 4 5 6", rows=B, cols=T*H*D) +K = matrix("2 1 4 3 6 5", rows=B, cols=T*H*D) + +# V1 and V2 differ only in future tokens (token2, token3) +V1 = matrix("1 10 2 20 3 30", rows=B, cols=T*H*D) +V2 = matrix("1 10 200 2000 300 3000", rows=B, cols=T*H*D) + +[context_causal_1, A1, M1] = multi_attention::forward_causal(Q, K, V1, H, T, D, dropout_p) +[context_causal_2, A2, M2] = multi_attention::forward_causal(Q, K, V2, H, T, D, dropout_p) +[context_noncausal_1, A3, M3] = multi_attention::forward(Q, K, V1, H, T, D, dropout_p) +[context_noncausal_2, A4, M4] = multi_attention::forward(Q, K, V2, H, T, D, dropout_p) + +# Reconstruct token-major shape from flattened context for B=1, H=1. +C_causal_1 = matrix(context_causal_1[1,], rows=T, cols=H*D) +C_causal_2 = matrix(context_causal_2[1,], rows=T, cols=H*D) +C_noncausal_1 = matrix(context_noncausal_1[1,], rows=T, cols=H*D) +C_noncausal_2 = matrix(context_noncausal_2[1,], rows=T, cols=H*D) + +# Future token changes should not affect first token in causal mode. +causal_token1_diff = max(abs(C_causal_1[1,] - C_causal_2[1,])) +# But they should affect later tokens. +causal_token3_diff = max(abs(C_causal_1[3,] - C_causal_2[3,])) +# In non-causal mode, first token can attend to all tokens and should change. +noncausal_token1_diff = max(abs(C_noncausal_1[1,] - C_noncausal_2[1,])) + +write(causal_token1_diff, $1, format="text") +write(causal_token3_diff, $2, format="text") +write(noncausal_token1_diff, $3, format="text")
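+
+# Expected outcome (asserted in MultiAttentionLayerTest): causal_token1_diff
+# is ~0 (< 1e-6), while causal_token3_diff and noncausal_token1_diff are
+# clearly non-zero (> 1e-6).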