From a7c59523e2063982ce56b2771479c04636c54830 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 29 Dec 2022 15:26:46 +0100 Subject: [PATCH 01/21] move multiheadattention from Metalhead --- src/layers/attention.jl | 63 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 src/layers/attention.jl diff --git a/src/layers/attention.jl b/src/layers/attention.jl new file mode 100644 index 0000000000..dc7dac18f7 --- /dev/null +++ b/src/layers/attention.jl @@ -0,0 +1,63 @@ +""" + MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, + attn_dropout_prob = 0., proj_dropout_prob = 0.) + +Multi-head self-attention layer. + +# Arguments + +- `planes`: number of input channels +- `nheads`: number of heads +- `qkv_bias`: whether to use bias in the layer to get the query, key and value +- `attn_dropout_prob`: dropout probability after the self-attention layer +- `proj_dropout_prob`: dropout probability after the projection layer +""" +struct MultiHeadAttention{P, Q, R} + nheads::Int + qkv_layer::P + attn_drop::Q + projection::R +end + +@functor MHAttention + +function MultiHeadAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, + attn_dropout_prob = 0.0, proj_dropout_prob = 0.0) + @assert planes % nheads==0 "planes should be divisible by nheads" + qkv_layer = Dense(planes, planes * 3; bias = qkv_bias) + attn_drop = Dropout(attn_dropout_prob) + proj = Chain(Dense(planes, planes), Dropout(proj_dropout_prob)) + return MultiHeadAttention(nheads, qkv_layer, attn_drop, proj) +end + +function (m::MultiHeadAttention)(x::AbstractArray{T, 3}) where {T} + nfeatures, seq_len, batch_size = size(x) + x_reshaped = reshape(x, nfeatures, seq_len * batch_size) + qkv = m.qkv_layer(x_reshaped) + qkv_reshaped = reshape(qkv, nfeatures ÷ m.nheads, m.nheads, seq_len, 3 * batch_size) + query, key, value = chunk(qkv_reshaped, 3; dims = 4) + scale = convert(T, sqrt(size(query, 1) / m.nheads)) + key_reshaped = reshape(permutedims(key, (2, 1, 3, 4)), m.nheads, nfeatures ÷ m.nheads, + seq_len * batch_size) + query_reshaped = reshape(permutedims(query, (1, 2, 3, 4)), nfeatures ÷ m.nheads, + m.nheads, seq_len * batch_size) + + attention = softmax(batched_mul(query_reshaped, key_reshaped) .* scale) + attention = m.attn_drop(attention) + + value_reshaped = reshape(permutedims(value, (1, 2, 3, 4)), nfeatures ÷ m.nheads, + m.nheads, seq_len * batch_size) + pre_projection = reshape(batched_mul(attention, value_reshaped), + (nfeatures, seq_len, batch_size)) + y = m.projection(reshape(pre_projection, size(pre_projection, 1), :)) + return reshape(y, :, seq_len, batch_size) +end + +using Flux, Functors, Test, NNlib, MLUtils + +mha = MultiHeadAttention(64, 8) +sz = (64, 100, 32) +x = rand(Float32, sz) +y = mha(x) +@test y isa Array{Float32, 3} +@test size(y) == sz From 40c706d44b3553ff9703547634b6db9d4fba8e97 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 30 Dec 2022 17:27:15 +0100 Subject: [PATCH 02/21] generic attention --- Project.toml | 4 + src/layers/attention.jl | 212 ++++++++++++++++++++++++++++++---------- 2 files changed, 166 insertions(+), 50 deletions(-) diff --git a/Project.toml b/Project.toml index c0a027a551..14f0805593 100644 --- a/Project.toml +++ b/Project.toml @@ -5,13 +5,16 @@ version = "0.13.14" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Functors = 
"d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" +NeuralAttentionlib = "12afc1b8-fad6-47e1-9132-84abc478905f" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Preferences = "21216c6a-2e73-6563-6e65-726566657250" @@ -22,6 +25,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" +Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] diff --git a/src/layers/attention.jl b/src/layers/attention.jl index dc7dac18f7..5b32016dbf 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -1,63 +1,175 @@ +using Flux, Test, LinearAlgebra, Random, Statistics +using CUDA, CUDAKernels, LoopVectorization +using Tullio +using NeuralAttentionlib +using BenchmarkTools + +const A3{T} = AbstractArray{T, 3} + """ - MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, - attn_dropout_prob = 0., proj_dropout_prob = 0.) + MultiHeadAttention(dims, num_heads; + [bias, init, attn_dropout_prob, proj_dropout_prob]) -Multi-head self-attention layer. +Multi-head dot-product attention layer. # Arguments -- `planes`: number of input channels +- `dims`: ... - `nheads`: number of heads -- `qkv_bias`: whether to use bias in the layer to get the query, key and value +- `init`: weight initializer for the Dense layers. +- `bias` : whether pointwise QKVO dense transforms use bias. 
- `attn_dropout_prob`: dropout probability after the self-attention layer - `proj_dropout_prob`: dropout probability after the projection layer + +# Forward + +- `in_q`: input tensor of shape `(batch_size, seq_len, dims) +- `in_k`: input tensor of shape `(batch_size, seq_len, dims) +- `in_v`: input tensor of shape `(batch_size, seq_len, dims) +- `mask`: input tensor of shape `(batch_size, seq_len, seq_len)` +- `return_weights`: whether to return the attention weights + +# Examples + +```julia +mha = MultiHeadAttention(64, 8) +``` """ -struct MultiHeadAttention{P, Q, R} - nheads::Int - qkv_layer::P - attn_drop::Q - projection::R -end - -@functor MHAttention - -function MultiHeadAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, - attn_dropout_prob = 0.0, proj_dropout_prob = 0.0) - @assert planes % nheads==0 "planes should be divisible by nheads" - qkv_layer = Dense(planes, planes * 3; bias = qkv_bias) - attn_drop = Dropout(attn_dropout_prob) - proj = Chain(Dense(planes, planes), Dropout(proj_dropout_prob)) - return MultiHeadAttention(nheads, qkv_layer, attn_drop, proj) -end - -function (m::MultiHeadAttention)(x::AbstractArray{T, 3}) where {T} - nfeatures, seq_len, batch_size = size(x) - x_reshaped = reshape(x, nfeatures, seq_len * batch_size) - qkv = m.qkv_layer(x_reshaped) - qkv_reshaped = reshape(qkv, nfeatures ÷ m.nheads, m.nheads, seq_len, 3 * batch_size) - query, key, value = chunk(qkv_reshaped, 3; dims = 4) - scale = convert(T, sqrt(size(query, 1) / m.nheads)) - key_reshaped = reshape(permutedims(key, (2, 1, 3, 4)), m.nheads, nfeatures ÷ m.nheads, - seq_len * batch_size) - query_reshaped = reshape(permutedims(query, (1, 2, 3, 4)), nfeatures ÷ m.nheads, - m.nheads, seq_len * batch_size) - - attention = softmax(batched_mul(query_reshaped, key_reshaped) .* scale) - attention = m.attn_drop(attention) +struct MultiHeadAttention + num_heads::Int + qkv_proj + attn_drop + out_proj +end + +@functor MultiHeadAttention + +function MultiHeadAttention(dims, num_heads::Int; + bias::Bool = false, + # init = glorot_uniform, # TODO + attn_dropout_prob = 0.0, + out_proj_dropout_prob = 0.0) + + dims = mha_process_dims(dims) + @assert dims.qkv % num_heads == 0 "qkv_dim should be divisible by num_heads" + qkv_proj = QKVProj((dims.q_in, dims.k_in, dims.v_in) => dims.qkv; bias) + attn_drop = Dropout(attn_dropout_prob) + out_proj = Chain(Dense(dims.qkv => dims.out; bias), Dropout(out_proj_dropout_prob)) + return MultiHeadAttention(num_heads, qkv_proj, attn_drop, out_proj) +end + +mha_process_dims(dims::Int) = (; q_in = dims, k_in = dims, v_in = dims, qkv = dims, out = dims) +mha_process_dims((in, (qkv, out))::Pair{Int, <:Pair}) = (; q_in = in, k_in = in, v_in = in, qkv, out) +mha_process_dims((in, (qkv, out))::Pair{<:Tuple, <:Pair}) = (; q_in = in[1], k_in = in[2], v_in = in[3], qkv, out) + +# self-attention +(m::MultiHeadAttention)(x; kws...) = m(x, x, x; kws...) 
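+# A minimal self-attention usage sketch (illustrative only; it mirrors the perf/test
+# code at the bottom of this file):
+#   mha = MultiHeadAttention(64, 8)
+#   x = rand(Float32, 64, 100, 32)   # (dims, seq_len, batch_size)
+#   y = mha(x)                       # same as mha(x, x, x); size(y) == (64, 100, 32)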
+ +function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; with_weights=false, v=:tullio) + ## [q_in] = [q_in_dim, q_len, batch_size] + ## [k_in] = [k_in_dim, kv_len, batch_size] + ## [v_in] = [v_in_dim, kv_len, batch_size] + + if v == :tullio + q, k, v = m.qkv_proj(q_in, k_in, v_in, m.num_heads) + # [q] = [qkv_dim / num_heads, num_heads, q_len, batch_size] + # [k] = [v] = [qkv_dim / num_heads, num_heads, kv_len, batch_size] - value_reshaped = reshape(permutedims(value, (1, 2, 3, 4)), nfeatures ÷ m.nheads, - m.nheads, seq_len * batch_size) - pre_projection = reshape(batched_mul(attention, value_reshaped), - (nfeatures, seq_len, batch_size)) - y = m.projection(reshape(pre_projection, size(pre_projection, 1), :)) - return reshape(y, :, seq_len, batch_size) + x, α = dot_product_attention(q, k, v; dropout=m.attn_drop) + x = reshape(x, :, size(x, 3), size(x, 4)) + elseif v == :nnalib + q, k, v = m.qkv_proj(q_in, k_in, v_in) + x = NeuralAttentionlib.multihead_qkv_attention(m.num_heads, q, k, v) + else + error("Unknown attention implementation") + end + + x = m.out_proj(x) + + return x + # return with_weights ? (x, α) : x end -using Flux, Functors, Test, NNlib, MLUtils +# Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html?highlight=dot_product_attention +function dot_product_attention(q, k, v; dropout=nothing) + α = dot_product_attention_weights(q, k; dropout) + # [α] = [kv_len, q_len, num_heads, batch_size] + @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b] + # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size] + + return x, α +end -mha = MultiHeadAttention(64, 8) -sz = (64, 100, 32) -x = rand(Float32, sz) -y = mha(x) -@test y isa Array{Float32, 3} -@test size(y) == sz +function dot_product_attention_weights(q, k; dropout=nothing) + @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] + # [α] = [kv_len, q_len, num_heads, batch_size] + α = softmax(α, dims=1) + return dropout === nothing ? α : dropout(α) +end + + +struct QKVProj + k_proj::Dense + v_proj::Dense + q_proj::Dense +end + +@functor QKVProj + +function QKVProj((in_dim, qkv_dim)::Pair; bias = false) + q_in_dim, k_in_dim, v_in_dim = in_dim + return QKVProj( + Dense(k_in_dim => qkv_dim; bias), + Dense(v_in_dim => qkv_dim; bias), + Dense(q_in_dim => qkv_dim; bias) + ) +end + +function (proj::QKVProj)(q_in, k_in, v_in, num_heads) + q = proj.q_proj(q_in) + sz = size(q) + newsz = (sz[1] ÷ num_heads, num_heads, sz[2:end]...) 
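+    # newsz = (qkv_dim ÷ num_heads, num_heads, seq_len, batch_size): the projected
+    # features are split into one slice per attention head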
+ q = reshape(q, newsz) + k = reshape(proj.k_proj(k_in), newsz) + v = reshape(proj.v_proj(v_in), newsz) + return q, k, v +end + +function (proj::QKVProj)(q_in, k_in, v_in) + return (proj.q_proj(q_in), proj.k_proj(k_in), proj.v_proj(v_in)) +end + + +function perf(dim, len, batch_size, num_heads) + mha = MultiHeadAttention(dim, num_heads) + x = rand(Float32, (dim, len, batch_size)) + + y = mha(x, x, x) + @test y isa Array{Float32, 3} + @test size(y) == (dim, len, batch_size) + + + println("tullio") + @btime $mha($x, v=:tullio); + @btime gradient(m -> sum(m($x, v=:tullio)), $mha); + + println("nnalib") + @btime $mha($x, $x, $x, v=:nnalib); + @btime gradient(m -> sum(m($x, v=:nnalib)), $mha); + + if CUDA.functional() + mha_gpu = mha |> gpu + x_gpu = x |> gpu + + println("tullio - gpu") + @btime $mha_gpu($x_gpu, v=:tullio); + @btime gradient(m -> sum(m($x_gpu, v=:tullio)), $mha_gpu); + + println("nnalib - gpu") + @btime CUDA.@sync $mha_gpu($x_gpu, v=:nnalib); + @btime CUDA.@sync gradient(m -> sum(m($x_gpu, v=:nnalib)), $mha_gpu); + end + return nothing +end + +perf(64, 100, 32, 8) From 59edf23c1a8300288535ebe9298ec1555754d229 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Fri, 30 Dec 2022 18:31:42 +0100 Subject: [PATCH 03/21] [ci skip] updates --- Project.toml | 1 + src/layers/attention.jl | 38 +++++++++++++++++++++++++++++--------- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 14f0805593..4941e1b01a 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" diff --git a/src/layers/attention.jl b/src/layers/attention.jl index 5b32016dbf..bbd4f5a3b9 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -1,9 +1,9 @@ -using Flux, Test, LinearAlgebra, Random, Statistics -using CUDA, CUDAKernels, LoopVectorization +using Flux, Functors, Test, LinearAlgebra, Random, Statistics +using CUDA, CUDAKernels, KernelAbstractions, LoopVectorization using Tullio using NeuralAttentionlib using BenchmarkTools - +CUDA.allowscalar(false) const A3{T} = AbstractArray{T, 3} """ @@ -144,11 +144,6 @@ function perf(dim, len, batch_size, num_heads) mha = MultiHeadAttention(dim, num_heads) x = rand(Float32, (dim, len, batch_size)) - y = mha(x, x, x) - @test y isa Array{Float32, 3} - @test size(y) == (dim, len, batch_size) - - println("tullio") @btime $mha($x, v=:tullio); @btime gradient(m -> sum(m($x, v=:tullio)), $mha); @@ -172,4 +167,29 @@ function perf(dim, len, batch_size, num_heads) return nothing end -perf(64, 100, 32, 8) +function test(dim, len, batch_size, num_heads) + mha = MultiHeadAttention(dim, num_heads) + x = rand(Float32, (dim, len, batch_size)) + y = mha(x, v=:tullio) + @test y isa Array{Float32, 3} + @test size(y) == (dim, len, batch_size) + y2 = mha(x, v=:nnalib) + @test size(y) == size(y2) + @test y2 ≈ y + + if CUDA.functional() + mha_gpu = mha |> gpu + x_gpu = x |> gpu + + y_gpu = mha_gpu(x_gpu, v=:tullio) + y_gpu2 = mha_gpu(x_gpu, v=:nnalib) + @test Array(y_gpu) ≈ Array(y_gpu2) + @test Array(y_gpu) ≈ y + end + return nothing +end + + +test(12, 3, 2, 4) + +perf(64, 100, 32, 4) From 
364695628905bfa28851342cb79854c63df9b3ec Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Sat, 31 Dec 2022 11:01:47 +0100 Subject: [PATCH 04/21] [ci skip] updates --- src/layers/attention.jl | 126 ++++++++++++++++++++++++---------------- 1 file changed, 76 insertions(+), 50 deletions(-) diff --git a/src/layers/attention.jl b/src/layers/attention.jl index bbd4f5a3b9..958d03867a 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -1,10 +1,12 @@ using Flux, Functors, Test, LinearAlgebra, Random, Statistics -using CUDA, CUDAKernels, KernelAbstractions, LoopVectorization -using Tullio +using CUDA +using CUDAKernels, KernelAbstractions, LoopVectorization, Tullio using NeuralAttentionlib using BenchmarkTools CUDA.allowscalar(false) + const A3{T} = AbstractArray{T, 3} +const A4{T} = AbstractArray{T, 4} """ MultiHeadAttention(dims, num_heads; @@ -48,7 +50,8 @@ function MultiHeadAttention(dims, num_heads::Int; bias::Bool = false, # init = glorot_uniform, # TODO attn_dropout_prob = 0.0, - out_proj_dropout_prob = 0.0) + out_proj_dropout_prob = 0.0, + self=false) dims = mha_process_dims(dims) @assert dims.qkv % num_heads == 0 "qkv_dim should be divisible by num_heads" @@ -58,48 +61,59 @@ function MultiHeadAttention(dims, num_heads::Int; return MultiHeadAttention(num_heads, qkv_proj, attn_drop, out_proj) end -mha_process_dims(dims::Int) = (; q_in = dims, k_in = dims, v_in = dims, qkv = dims, out = dims) -mha_process_dims((in, (qkv, out))::Pair{Int, <:Pair}) = (; q_in = in, k_in = in, v_in = in, qkv, out) -mha_process_dims((in, (qkv, out))::Pair{<:Tuple, <:Pair}) = (; q_in = in[1], k_in = in[2], v_in = in[3], qkv, out) +mha_process_dims(dims::Int) = + (; q_in = dims, k_in = dims, v_in = dims, qkv = dims, out = dims) + +mha_process_dims((in, (qkv, out))::Pair{Int, <:Pair{Int, Int}}) = + (; q_in = in, k_in = in, v_in = in, qkv, out) + +mha_process_dims((in, (qkv, out))::Pair{<:Tuple, <:Pair{Int, Int}}) = + (; q_in = in[1], k_in = in[2], v_in = in[3], qkv, out) # self-attention (m::MultiHeadAttention)(x; kws...) = m(x, x, x; kws...) -function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; with_weights=false, v=:tullio) +function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; with_weights=false, impl=:tullio) ## [q_in] = [q_in_dim, q_len, batch_size] ## [k_in] = [k_in_dim, kv_len, batch_size] ## [v_in] = [v_in_dim, kv_len, batch_size] - if v == :tullio - q, k, v = m.qkv_proj(q_in, k_in, v_in, m.num_heads) - # [q] = [qkv_dim / num_heads, num_heads, q_len, batch_size] - # [k] = [v] = [qkv_dim / num_heads, num_heads, kv_len, batch_size] - - x, α = dot_product_attention(q, k, v; dropout=m.attn_drop) - x = reshape(x, :, size(x, 3), size(x, 4)) - elseif v == :nnalib - q, k, v = m.qkv_proj(q_in, k_in, v_in) - x = NeuralAttentionlib.multihead_qkv_attention(m.num_heads, q, k, v) + q, k, v = m.qkv_proj(q_in, k_in, v_in) + # [q] = [qkv_dim, q_len, batch_size] + # [k] = [v] = [qkv_dim, kv_len, batch_size] + if impl == :tullio + x, α = dot_product_attention(m.num_heads, q, k, v; dropout=m.attn_drop) + elseif impl == :nnalib + x, α = NeuralAttentionlib.multihead_qkv_attention( + NeuralAttentionlib.score_returning, + m.num_heads, q, k, v) else error("Unknown attention implementation") end x = m.out_proj(x) - return x - # return with_weights ? (x, α) : x + return with_weights ? 
(x, α) : x end -# Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html?highlight=dot_product_attention -function dot_product_attention(q, k, v; dropout=nothing) +# Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html +function dot_product_attention(q::A4, k::A4, v::A4; dropout=nothing) α = dot_product_attention_weights(q, k; dropout) # [α] = [kv_len, q_len, num_heads, batch_size] @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b] # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size] - return x, α end +function dot_product_attention(num_heads::Int, q::A3, k::A3, v::A3; kws...) + q, k, v = reshape_heads.((q, k, v), num_heads) + x, α = dot_product_attention(q, k, v; kws...) + return flatten_heads(x), α +end + +reshape_heads(x, num_heads) = reshape(x, size(x, 1) ÷ num_heads, num_heads, size(x)[2:end]...) +flatten_heads(x) = reshape(x, :, size(x)[3:end]...) + function dot_product_attention_weights(q, k; dropout=nothing) @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] # [α] = [kv_len, q_len, num_heads, batch_size] @@ -125,16 +139,6 @@ function QKVProj((in_dim, qkv_dim)::Pair; bias = false) ) end -function (proj::QKVProj)(q_in, k_in, v_in, num_heads) - q = proj.q_proj(q_in) - sz = size(q) - newsz = (sz[1] ÷ num_heads, num_heads, sz[2:end]...) - q = reshape(q, newsz) - k = reshape(proj.k_proj(k_in), newsz) - v = reshape(proj.v_proj(v_in), newsz) - return q, k, v -end - function (proj::QKVProj)(q_in, k_in, v_in) return (proj.q_proj(q_in), proj.k_proj(k_in), proj.v_proj(v_in)) end @@ -145,51 +149,73 @@ function perf(dim, len, batch_size, num_heads) x = rand(Float32, (dim, len, batch_size)) println("tullio") - @btime $mha($x, v=:tullio); - @btime gradient(m -> sum(m($x, v=:tullio)), $mha); + @btime $mha($x, impl=:tullio); + @btime gradient(m -> sum(m($x, impl=:tullio)), $mha); println("nnalib") - @btime $mha($x, $x, $x, v=:nnalib); - @btime gradient(m -> sum(m($x, v=:nnalib)), $mha); + @btime $mha($x, $x, $x, impl=:nnalib); + @btime gradient(m -> sum(m($x, impl=:nnalib)), $mha); if CUDA.functional() mha_gpu = mha |> gpu x_gpu = x |> gpu println("tullio - gpu") - @btime $mha_gpu($x_gpu, v=:tullio); - @btime gradient(m -> sum(m($x_gpu, v=:tullio)), $mha_gpu); + @btime $mha_gpu($x_gpu, impl=:tullio); + @btime gradient(m -> sum(m($x_gpu, impl=:tullio)), $mha_gpu); println("nnalib - gpu") - @btime CUDA.@sync $mha_gpu($x_gpu, v=:nnalib); - @btime CUDA.@sync gradient(m -> sum(m($x_gpu, v=:nnalib)), $mha_gpu); + @btime CUDA.@sync $mha_gpu($x_gpu, impl=:nnalib); + @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:nnalib)), $mha_gpu); end return nothing end -function test(dim, len, batch_size, num_heads) +function test(dim, num_heads, len, batch_size) mha = MultiHeadAttention(dim, num_heads) x = rand(Float32, (dim, len, batch_size)) - y = mha(x, v=:tullio) + y, α = mha(x, impl=:tullio, with_weights=true) @test y isa Array{Float32, 3} @test size(y) == (dim, len, batch_size) - y2 = mha(x, v=:nnalib) + @test α isa Array{Float32, 4} + @test size(α) == (len, len, num_heads, batch_size) + + y2, α2 = mha(x, impl=:nnalib, with_weights=true) @test size(y) == size(y2) - @test y2 ≈ y + @test y2 ≈ y atol=1e-1 + @test size(α) == size(α2) + @test α2 ≈ α atol=1e-1 if CUDA.functional() mha_gpu = mha |> gpu x_gpu = x |> gpu - y_gpu = mha_gpu(x_gpu, v=:tullio) - y_gpu2 = mha_gpu(x_gpu, v=:nnalib) - @test Array(y_gpu) ≈ Array(y_gpu2) + y_gpu = mha_gpu(x_gpu, impl=:tullio) + y_gpu2 = 
mha_gpu(x_gpu, impl=:nnalib) + @test Array(y_gpu) ≈ Array(y_gpu2) atol=1e-1 @test Array(y_gpu) ≈ y end return nothing end -test(12, 3, 2, 4) - -perf(64, 100, 32, 4) +test(4, 2, 2, 1) + +perf(128, 8, 128, 32) +# tullio +# 5.862 ms (85 allocations: 6.75 MiB) +# 14.291 ms (1046 allocations: 17.17 MiB) +# nnalib +# 6.331 ms (90 allocations: 7.75 MiB) +# 16.186 ms (690 allocations: 16.17 MiB) +# tullio - gpu +# 141.365 μs (499 allocations: 22.81 KiB) +# 804.018 μs (2228 allocations: 113.45 KiB) +# nnalib - gpu +# 163.487 μs (410 allocations: 18.02 KiB) +# 673.463 μs (1521 allocations: 84.64 KiB) + +dim = 4; num_heads=2; len=2; batch_size=1 +mha = MultiHeadAttention(dim, num_heads) +x = rand(Float32, (dim, len, batch_size)) +y, α = mha(x, impl=:tullio, with_weights=true) From 26442830279fbdebeb30250b4ac62680cea7d29b Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Sun, 1 Jan 2023 06:52:37 +0100 Subject: [PATCH 05/21] [ci skip] fix tullio impl --- src/layers/attention.jl | 90 ++++++++++++++++++++++++----------------- test_jax.py | 17 ++++++++ 2 files changed, 69 insertions(+), 38 deletions(-) create mode 100644 test_jax.py diff --git a/src/layers/attention.jl b/src/layers/attention.jl index 958d03867a..70b7222579 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -3,6 +3,7 @@ using CUDA using CUDAKernels, KernelAbstractions, LoopVectorization, Tullio using NeuralAttentionlib using BenchmarkTools +using Flux: glorot_uniform CUDA.allowscalar(false) const A3{T} = AbstractArray{T, 3} @@ -48,27 +49,42 @@ end function MultiHeadAttention(dims, num_heads::Int; bias::Bool = false, - # init = glorot_uniform, # TODO + init = glorot_uniform, attn_dropout_prob = 0.0, - out_proj_dropout_prob = 0.0, - self=false) + out_proj_dropout_prob = 0.0) dims = mha_process_dims(dims) - @assert dims.qkv % num_heads == 0 "qkv_dim should be divisible by num_heads" - qkv_proj = QKVProj((dims.q_in, dims.k_in, dims.v_in) => dims.qkv; bias) + @assert dims.qk % num_heads == 0 "qk_dim should be divisible by num_heads" + qkv_proj = QKVProj(dims; bias, init) attn_drop = Dropout(attn_dropout_prob) - out_proj = Chain(Dense(dims.qkv => dims.out; bias), Dropout(out_proj_dropout_prob)) + out_proj = Chain(Dense(dims.v => dims.out; bias, init), Dropout(out_proj_dropout_prob)) return MultiHeadAttention(num_heads, qkv_proj, attn_drop, out_proj) end +# The following inputs are equivalent: +# 8 +# 8 => 8 => 8 +# (8, 8, 8) => 8 => 8 +# 8 => (8, 8) => 8 +# (8, 8, 8) => (8, 8) => 8 # (q_in, k_in, v_in) => (qk, v) => out mha_process_dims(dims::Int) = - (; q_in = dims, k_in = dims, v_in = dims, qkv = dims, out = dims) + (; q_in = dims, k_in = dims, v_in = dims, + qk = dims, v = dims, out = dims) mha_process_dims((in, (qkv, out))::Pair{Int, <:Pair{Int, Int}}) = - (; q_in = in, k_in = in, v_in = in, qkv, out) + (; q_in = in, k_in = in, v_in = in, + qk = qkv, v = qkv, out) mha_process_dims((in, (qkv, out))::Pair{<:Tuple, <:Pair{Int, Int}}) = - (; q_in = in[1], k_in = in[2], v_in = in[3], qkv, out) + (; q_in = in[1], k_in = in[2], v_in = in[3], + qk = qkv, v = qkv, out) + +mha_process_dims((in, ((qk, v), out))::Pair{<:Tuple, <:Pair{<:Tuple, Int}}) = + (; q_in = in[1], k_in = in[2], v_in = in[3], qk, v, out) + +mha_process_dims((in, ((qk, v), out))::Pair{Int, <:Pair{<:Tuple, Int}}) = + (; q_in = in, k_in = in, v_in = in, qk, v, out) + # self-attention (m::MultiHeadAttention)(x; kws...) = m(x, x, x; kws...) 
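+# Illustrative sketch of the equivalent `dims` specs listed above (not part of the layer):
+# each of these returns (; q_in = 8, k_in = 8, v_in = 8, qk = 8, v = 8, out = 8):
+#   mha_process_dims(8)
+#   mha_process_dims(8 => 8 => 8)
+#   mha_process_dims((8, 8, 8) => 8 => 8)
+#   mha_process_dims(8 => (8, 8) => 8)
+#   mha_process_dims((8, 8, 8) => (8, 8) => 8)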
@@ -79,11 +95,13 @@ function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; with_weights=fals ## [v_in] = [v_in_dim, kv_len, batch_size] q, k, v = m.qkv_proj(q_in, k_in, v_in) - # [q] = [qkv_dim, q_len, batch_size] - # [k] = [v] = [qkv_dim, kv_len, batch_size] + # [q] = [qk_dim, q_len, batch_size] + # [k] = [qk_dim, kv_len, batch_size] + # [v] = [v_dim, kv_len, batch_size] + if impl == :tullio x, α = dot_product_attention(m.num_heads, q, k, v; dropout=m.attn_drop) - elseif impl == :nnalib + elseif impl == :nalib x, α = NeuralAttentionlib.multihead_qkv_attention( NeuralAttentionlib.score_returning, m.num_heads, q, k, v) @@ -114,7 +132,9 @@ end reshape_heads(x, num_heads) = reshape(x, size(x, 1) ÷ num_heads, num_heads, size(x)[2:end]...) flatten_heads(x) = reshape(x, :, size(x)[3:end]...) -function dot_product_attention_weights(q, k; dropout=nothing) +function dot_product_attention_weights(q::A4{T}, k::A4{T}; + dropout=nothing) where T + q = q ./ T(√size(q, 1)) @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] # [α] = [kv_len, q_len, num_heads, batch_size] α = softmax(α, dims=1) @@ -123,20 +143,19 @@ end struct QKVProj + q_proj::Dense k_proj::Dense v_proj::Dense - q_proj::Dense end @functor QKVProj -function QKVProj((in_dim, qkv_dim)::Pair; bias = false) - q_in_dim, k_in_dim, v_in_dim = in_dim +function QKVProj(dims; bias = false, init=glorot_uniform) return QKVProj( - Dense(k_in_dim => qkv_dim; bias), - Dense(v_in_dim => qkv_dim; bias), - Dense(q_in_dim => qkv_dim; bias) - ) + Dense(dims.q_in => dims.qk; bias, init), + Dense(dims.k_in => dims.qk; bias, init), + Dense(dims.v_in => dims.v; bias, init) + ) end function (proj::QKVProj)(q_in, k_in, v_in) @@ -152,9 +171,9 @@ function perf(dim, len, batch_size, num_heads) @btime $mha($x, impl=:tullio); @btime gradient(m -> sum(m($x, impl=:tullio)), $mha); - println("nnalib") - @btime $mha($x, $x, $x, impl=:nnalib); - @btime gradient(m -> sum(m($x, impl=:nnalib)), $mha); + println("nalib") + @btime $mha($x, $x, $x, impl=:nalib); + @btime gradient(m -> sum(m($x, impl=:nalib)), $mha); if CUDA.functional() mha_gpu = mha |> gpu @@ -164,9 +183,9 @@ function perf(dim, len, batch_size, num_heads) @btime $mha_gpu($x_gpu, impl=:tullio); @btime gradient(m -> sum(m($x_gpu, impl=:tullio)), $mha_gpu); - println("nnalib - gpu") - @btime CUDA.@sync $mha_gpu($x_gpu, impl=:nnalib); - @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:nnalib)), $mha_gpu); + println("nalib - gpu") + @btime CUDA.@sync $mha_gpu($x_gpu, impl=:nalib); + @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:nalib)), $mha_gpu); end return nothing end @@ -180,19 +199,19 @@ function test(dim, num_heads, len, batch_size) @test α isa Array{Float32, 4} @test size(α) == (len, len, num_heads, batch_size) - y2, α2 = mha(x, impl=:nnalib, with_weights=true) + y2, α2 = mha(x, impl=:nalib, with_weights=true) @test size(y) == size(y2) - @test y2 ≈ y atol=1e-1 + @test y2 ≈ y @test size(α) == size(α2) - @test α2 ≈ α atol=1e-1 + @test α2 ≈ α if CUDA.functional() mha_gpu = mha |> gpu x_gpu = x |> gpu y_gpu = mha_gpu(x_gpu, impl=:tullio) - y_gpu2 = mha_gpu(x_gpu, impl=:nnalib) - @test Array(y_gpu) ≈ Array(y_gpu2) atol=1e-1 + y_gpu2 = mha_gpu(x_gpu, impl=:nalib) + @test Array(y_gpu) ≈ Array(y_gpu2) @test Array(y_gpu) ≈ y end return nothing @@ -205,17 +224,12 @@ perf(128, 8, 128, 32) # tullio # 5.862 ms (85 allocations: 6.75 MiB) # 14.291 ms (1046 allocations: 17.17 MiB) -# nnalib +# nalib # 6.331 ms (90 allocations: 7.75 MiB) # 16.186 ms (690 allocations: 16.17 MiB) # tullio - gpu # 141.365 μs 
(499 allocations: 22.81 KiB) # 804.018 μs (2228 allocations: 113.45 KiB) -# nnalib - gpu +# nalib - gpu # 163.487 μs (410 allocations: 18.02 KiB) # 673.463 μs (1521 allocations: 84.64 KiB) - -dim = 4; num_heads=2; len=2; batch_size=1 -mha = MultiHeadAttention(dim, num_heads) -x = rand(Float32, (dim, len, batch_size)) -y, α = mha(x, impl=:tullio, with_weights=true) diff --git a/test_jax.py b/test_jax.py new file mode 100644 index 0000000000..9db5c91f65 --- /dev/null +++ b/test_jax.py @@ -0,0 +1,17 @@ +#%% +import jax +import jax.numpy as jnp # JAX NumPy + +from flax import linen as nn # The Linen API + +#import numpy as np # Ordinary NumPy +#import optax # Optimizers +#import tensorflow_datasets as tfds # TFDS for MNIST +# %% +x = jnp.arange(16).reshape(1,2,2,4) / 16 +y = nn.dot_product_attention(x, x, x) +yt = y.transpose((3,2,1,0)) + +yt +yt.shape +# %% From 0c61daaac1ec7e806b9c973ec3025578f48fd7bc Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Sun, 1 Jan 2023 12:52:11 +0100 Subject: [PATCH 06/21] causal mask --- src/layers/attention.jl | 126 +++++++++++++++++++++++++--------------- 1 file changed, 80 insertions(+), 46 deletions(-) diff --git a/src/layers/attention.jl b/src/layers/attention.jl index 70b7222579..e109594abb 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -2,8 +2,10 @@ using Flux, Functors, Test, LinearAlgebra, Random, Statistics using CUDA using CUDAKernels, KernelAbstractions, LoopVectorization, Tullio using NeuralAttentionlib +using NeuralAttentionlib: score_returning using BenchmarkTools using Flux: glorot_uniform +using MLUtils CUDA.allowscalar(false) const A3{T} = AbstractArray{T, 3} @@ -18,19 +20,22 @@ Multi-head dot-product attention layer. # Arguments - `dims`: ... -- `nheads`: number of heads +- `num_heads`: number of heads. - `init`: weight initializer for the Dense layers. - `bias` : whether pointwise QKVO dense transforms use bias. - `attn_dropout_prob`: dropout probability after the self-attention layer - `proj_dropout_prob`: dropout probability after the projection layer # Forward + + (::MultiHeadAttention)(q_in, k_in, v_in; [mask, with_weights]) -- `in_q`: input tensor of shape `(batch_size, seq_len, dims) -- `in_k`: input tensor of shape `(batch_size, seq_len, dims) -- `in_v`: input tensor of shape `(batch_size, seq_len, dims) -- `mask`: input tensor of shape `(batch_size, seq_len, seq_len)` -- `return_weights`: whether to return the attention weights +- `q_in`: input array of size `( seq_len, dims) +- `k_in`: input array of size `( seq_len, dims) +- `v_in`: input array of size `( seq_len, dims) +- `mask`: input array broadcastable to size + `(kv_len, q_len, num_heads, batch_size)`. Default `nothing`. +- `with_weights`: Whether to return the attention weights. Default `false`. 
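+
+A shape-level sketch of a forward call (illustrative values only):
+
+```julia
+mha = MultiHeadAttention(64, 8)
+q  = rand(Float32, 64, 10, 32)   # (q_in_dim, q_len, batch_size)
+kv = rand(Float32, 64, 20, 32)   # (k_in_dim = v_in_dim, kv_len, batch_size)
+y, α = mha(q, kv, kv; with_weights = true)
+size(y)  # (64, 10, 32)
+size(α)  # (20, 10, 8, 32), i.e. (kv_len, q_len, num_heads, batch_size)
+```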
# Examples @@ -68,28 +73,33 @@ end # 8 => (8, 8) => 8 # (8, 8, 8) => (8, 8) => 8 # (q_in, k_in, v_in) => (qk, v) => out mha_process_dims(dims::Int) = - (; q_in = dims, k_in = dims, v_in = dims, - qk = dims, v = dims, out = dims) + (; q_in=dims, k_in=dims, v_in=dims, qk=dims, v=dims, out=dims) -mha_process_dims((in, (qkv, out))::Pair{Int, <:Pair{Int, Int}}) = - (; q_in = in, k_in = in, v_in = in, - qk = qkv, v = qkv, out) - -mha_process_dims((in, (qkv, out))::Pair{<:Tuple, <:Pair{Int, Int}}) = - (; q_in = in[1], k_in = in[2], v_in = in[3], - qk = qkv, v = qkv, out) - -mha_process_dims((in, ((qk, v), out))::Pair{<:Tuple, <:Pair{<:Tuple, Int}}) = - (; q_in = in[1], k_in = in[2], v_in = in[3], qk, v, out) - -mha_process_dims((in, ((qk, v), out))::Pair{Int, <:Pair{<:Tuple, Int}}) = - (; q_in = in, k_in = in, v_in = in, qk, v, out) +const TuplInt2 = Union{Int, Tuple{Int, Int}} +const TuplInt3 = Union{Int, Tuple{Int, Int, Int}} +function mha_process_dims((in, (qkv, out))::Pair{<:TuplInt3, <:Pair{<:TuplInt2, Int}}) + if in isa Int + q_in = k_in = v_in = in + else + q_in, k_in, v_in = in + end + if qkv isa Int + qk = v = qkv + else + qk, v = qkv + end + return (; q_in, k_in, v_in, qk, v, out) +end # self-attention -(m::MultiHeadAttention)(x; kws...) = m(x, x, x; kws...) +(m::MultiHeadAttention)(qkv; kws...) = m(qkv, qkv, qkv; kws...) + +# key and value are the same +(m::MultiHeadAttention)(q, kv; kws...) = m(q, kv, kv; kws...) -function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; with_weights=false, impl=:tullio) +function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; + with_weights=false, mask=nothing, impl=:tullio) ## [q_in] = [q_in_dim, q_len, batch_size] ## [k_in] = [k_in_dim, kv_len, batch_size] ## [v_in] = [v_in_dim, kv_len, batch_size] @@ -100,11 +110,9 @@ function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; with_weights=fals # [v] = [v_dim, kv_len, batch_size] if impl == :tullio - x, α = dot_product_attention(m.num_heads, q, k, v; dropout=m.attn_drop) + x, α = dot_product_attention(m.num_heads, q, k, v; mask, dropout=m.attn_drop) elseif impl == :nalib - x, α = NeuralAttentionlib.multihead_qkv_attention( - NeuralAttentionlib.score_returning, - m.num_heads, q, k, v) + x, α = NeuralAttentionlib.multihead_qkv_attention(score_returning, m.num_heads, q, k, v) else error("Unknown attention implementation") end @@ -114,14 +122,8 @@ function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; with_weights=fals return with_weights ? (x, α) : x end -# Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html -function dot_product_attention(q::A4, k::A4, v::A4; dropout=nothing) - α = dot_product_attention_weights(q, k; dropout) - # [α] = [kv_len, q_len, num_heads, batch_size] - @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b] - # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size] - return x, α -end +reshape_heads(x, num_heads) = reshape(x, size(x, 1) ÷ num_heads, num_heads, size(x)[2:end]...) +flatten_heads(x) = reshape(x, :, size(x)[3:end]...) function dot_product_attention(num_heads::Int, q::A3, k::A3, v::A3; kws...) q, k, v = reshape_heads.((q, k, v), num_heads) @@ -129,14 +131,32 @@ function dot_product_attention(num_heads::Int, q::A3, k::A3, v::A3; kws...) return flatten_heads(x), α end -reshape_heads(x, num_heads) = reshape(x, size(x, 1) ÷ num_heads, num_heads, size(x)[2:end]...) -flatten_heads(x) = reshape(x, :, size(x)[3:end]...) 
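+# For example (illustrative): with x of size (8, 5, 2),
+#   reshape_heads(x, 4) has size (2, 4, 5, 2)    # (dim ÷ num_heads, num_heads, len, batch)
+#   flatten_heads(reshape_heads(x, 4)) == x      # the two reshapes are inverses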
+# Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html +function dot_product_attention(q::A4, k::A4, v::A4; + dropout=nothing, bias=nothing, mask=nothing) + + α = dot_product_attention_weights(q, k; dropout, bias, mask) + # [α] = [kv_len, q_len, num_heads, batch_size] + @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b] + # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size] + return x, α +end function dot_product_attention_weights(q::A4{T}, k::A4{T}; - dropout=nothing) where T + dropout=nothing, mask=nothing, bias=nothing) where T + q = q ./ T(√size(q, 1)) @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] # [α] = [kv_len, q_len, num_heads, batch_size] + + if bias !== nothing + α = α .+ bias + end + if mask !== nothing + neginf = typemin(eltype(α)) + α = ifelse.(mask, α, neginf) + end + α = softmax(α, dims=1) return dropout === nothing ? α : dropout(α) end @@ -162,6 +182,13 @@ function (proj::QKVProj)(q_in, k_in, v_in) return (proj.q_proj(q_in), proj.k_proj(k_in), proj.v_proj(v_in)) end +function make_causal_mask(x::A3) + d, len, batch_size = size(x) + mask = tril(ones_like(x, (len, len))) + return mask +end + +@non_differentiable make_causal_mask(x) function perf(dim, len, batch_size, num_heads) mha = MultiHeadAttention(dim, num_heads) @@ -222,14 +249,21 @@ test(4, 2, 2, 1) perf(128, 8, 128, 32) # tullio -# 5.862 ms (85 allocations: 6.75 MiB) -# 14.291 ms (1046 allocations: 17.17 MiB) +# 5.475 ms (80 allocations: 7.25 MiB) +# 13.073 ms (1172 allocations: 18.18 MiB) +# tullio - 6 threads +# 4.818 ms (192 allocations: 7.26 MiB) +# 10.927 ms (1398 allocations: 18.19 MiB) # nalib -# 6.331 ms (90 allocations: 7.75 MiB) -# 16.186 ms (690 allocations: 16.17 MiB) +# 6.040 ms (91 allocations: 7.75 MiB) +# 14.542 ms (696 allocations: 16.17 MiB) +# nalib - 6 threads +# 7.832 ms (187 allocations: 7.76 MiB) +# 29.823 ms (988 allocations: 16.19 MiB) # tullio - gpu -# 141.365 μs (499 allocations: 22.81 KiB) -# 804.018 μs (2228 allocations: 113.45 KiB) +# 147.746 μs (523 allocations: 24.59 KiB) +# 957.111 μs (2413 allocations: 127.88 KiB) # nalib - gpu -# 163.487 μs (410 allocations: 18.02 KiB) -# 673.463 μs (1521 allocations: 84.64 KiB) +# 165.109 μs (411 allocations: 18.05 KiB) +# 659.685 μs (1527 allocations: 86.09 KiB) + From 6e7f5389e70608563eda9bb73d7a3561030b40d3 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Sun, 1 Jan 2023 18:09:10 +0100 Subject: [PATCH 07/21] [ci skip] mask --- src/layers/attention.jl | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/src/layers/attention.jl b/src/layers/attention.jl index e109594abb..b6be7f4cc6 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -6,6 +6,7 @@ using NeuralAttentionlib: score_returning using BenchmarkTools using Flux: glorot_uniform using MLUtils +using ChainRulesCore CUDA.allowscalar(false) const A3{T} = AbstractArray{T, 3} @@ -112,7 +113,7 @@ function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; if impl == :tullio x, α = dot_product_attention(m.num_heads, q, k, v; mask, dropout=m.attn_drop) elseif impl == :nalib - x, α = NeuralAttentionlib.multihead_qkv_attention(score_returning, m.num_heads, q, k, v) + x, α = NeuralAttentionlib.multihead_qkv_attention(score_returning, m.num_heads, q, k, v, mask) else error("Unknown attention implementation") end @@ -184,11 +185,16 @@ end function make_causal_mask(x::A3) d, len, batch_size = size(x) - mask = tril(ones_like(x, (len, len))) + mask = 
triu(trues_like(x, (len, len))) return mask end +trues_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), true) +falses_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), false) + @non_differentiable make_causal_mask(x) +@non_differentiable trues_like(::Any...) +@non_differentiable falses_like(::Any...) function perf(dim, len, batch_size, num_heads) mha = MultiHeadAttention(dim, num_heads) @@ -231,7 +237,13 @@ function test(dim, num_heads, len, batch_size) @test y2 ≈ y @test size(α) == size(α2) @test α2 ≈ α - + + mask = make_causal_mask(x) + y3, α3 = mha(x; impl=:tullio, with_weights=true, mask) + y4, α4 = mha(x, impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask()) + @test y ≈ y2 + @test α ≈ α2 + if CUDA.functional() mha_gpu = mha |> gpu x_gpu = x |> gpu @@ -244,8 +256,7 @@ function test(dim, num_heads, len, batch_size) return nothing end - -test(4, 2, 2, 1) +test(4, 2, 3, 1) perf(128, 8, 128, 32) # tullio @@ -267,3 +278,9 @@ perf(128, 8, 128, 32) # 165.109 μs (411 allocations: 18.05 KiB) # 659.685 μs (1527 allocations: 86.09 KiB) +dim = 2; len = 3; batch_size = 1; num_heads = 1 +mha = MultiHeadAttention(dim, num_heads) +x = rand(Float32, (dim, len, batch_size)) +mask = make_causal_mask(x) +y, α = mha(x; impl=:tullio, with_weights=true, mask) +y2, α2 = mha(x; impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask()) From e212b6b3d544322d4589e1d6a9413442f23ce1bf Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Sun, 1 Jan 2023 22:09:24 +0100 Subject: [PATCH 08/21] [ci skip] updates --- src/layers/attention.jl | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/src/layers/attention.jl b/src/layers/attention.jl index b6be7f4cc6..c521353aca 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -146,7 +146,7 @@ end function dot_product_attention_weights(q::A4{T}, k::A4{T}; dropout=nothing, mask=nothing, bias=nothing) where T - q = q ./ T(√size(q, 1)) + q = q ./ √T(size(q, 1)) @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] # [α] = [kv_len, q_len, num_heads, batch_size] @@ -173,10 +173,9 @@ end function QKVProj(dims; bias = false, init=glorot_uniform) return QKVProj( - Dense(dims.q_in => dims.qk; bias, init), - Dense(dims.k_in => dims.qk; bias, init), - Dense(dims.v_in => dims.v; bias, init) - ) + Dense(dims.q_in => dims.qk; bias, init), + Dense(dims.k_in => dims.qk; bias, init), + Dense(dims.v_in => dims.v; bias, init)) end function (proj::QKVProj)(q_in, k_in, v_in) @@ -224,32 +223,35 @@ function perf(dim, len, batch_size, num_heads) end function test(dim, num_heads, len, batch_size) - mha = MultiHeadAttention(dim, num_heads) - x = rand(Float32, (dim, len, batch_size)) - y, α = mha(x, impl=:tullio, with_weights=true) + mha = MultiHeadAttention(dim, num_heads) + q = rand(Float32, (dim, len, batch_size)) + k = rand(Float32, (dim, len, batch_size)) + v = rand(Float32, (dim, len, batch_size)) + + y, α = mha(q, k, v, impl=:tullio, with_weights=true) @test y isa Array{Float32, 3} @test size(y) == (dim, len, batch_size) @test α isa Array{Float32, 4} @test size(α) == (len, len, num_heads, batch_size) - y2, α2 = mha(x, impl=:nalib, with_weights=true) + y2, α2 = mha(q, k, v, impl=:nalib, with_weights=true) @test size(y) == size(y2) @test y2 ≈ y @test size(α) == size(α2) @test α2 ≈ α - mask = make_causal_mask(x) - y3, α3 = mha(x; impl=:tullio, with_weights=true, mask) - y4, α4 = mha(x, impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask()) - @test y ≈ y2 - 
@test α ≈ α2 + mask = make_causal_mask(q) + y3, α3 = mha(q, k, v; impl=:tullio, with_weights=true, mask) + y4, α4 = mha(q, k, v, impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask()) + @test y3 ≈ y4 + @test α3 ≈ α4 if CUDA.functional() mha_gpu = mha |> gpu - x_gpu = x |> gpu - - y_gpu = mha_gpu(x_gpu, impl=:tullio) - y_gpu2 = mha_gpu(x_gpu, impl=:nalib) + q_gpu, k_gpu, v_gpu = q |> gpu, k |> gpu, v |> gpu + + y_gpu = mha_gpu(q_gpu, k_gpu, v_gpu, impl=:tullio) + y_gpu2 = mha_gpu(q_gpu, k_gpu, v_gpu, impl=:nalib) @test Array(y_gpu) ≈ Array(y_gpu2) @test Array(y_gpu) ≈ y end From 2df4d5e3714c34122751c0b8cbb77f4409e02e8c Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Mon, 2 Jan 2023 00:35:40 +0100 Subject: [PATCH 09/21] [ci skip] add native implementation --- src/layers/attention.jl | 97 +++++++++++++++++++++++++++++++++++------ test_jax.py | 1 + 2 files changed, 85 insertions(+), 13 deletions(-) diff --git a/src/layers/attention.jl b/src/layers/attention.jl index c521353aca..11fbdd312c 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -111,9 +111,11 @@ function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; # [v] = [v_dim, kv_len, batch_size] if impl == :tullio - x, α = dot_product_attention(m.num_heads, q, k, v; mask, dropout=m.attn_drop) + x, α = dot_product_attention_tullio(m.num_heads, q, k, v; mask, dropout=m.attn_drop) elseif impl == :nalib x, α = NeuralAttentionlib.multihead_qkv_attention(score_returning, m.num_heads, q, k, v, mask) + elseif impl == :native + x, α = dot_product_attention_native(m.num_heads, q, k, v; mask, dropout=m.attn_drop) else error("Unknown attention implementation") end @@ -126,24 +128,30 @@ end reshape_heads(x, num_heads) = reshape(x, size(x, 1) ÷ num_heads, num_heads, size(x)[2:end]...) flatten_heads(x) = reshape(x, :, size(x)[3:end]...) -function dot_product_attention(num_heads::Int, q::A3, k::A3, v::A3; kws...) +function dot_product_attention_tullio(num_heads::Int, q::A3, k::A3, v::A3; kws...) q, k, v = reshape_heads.((q, k, v), num_heads) - x, α = dot_product_attention(q, k, v; kws...) + x, α = dot_product_attention_tullio(q, k, v; kws...) + return flatten_heads(x), α +end + +function dot_product_attention_native(num_heads::Int, q::A3, k::A3, v::A3; kws...) + q, k, v = reshape_heads.((q, k, v), num_heads) + x, α = dot_product_attention_native(q, k, v; kws...) return flatten_heads(x), α end # Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html -function dot_product_attention(q::A4, k::A4, v::A4; +function dot_product_attention_tullio(q::A4, k::A4, v::A4; dropout=nothing, bias=nothing, mask=nothing) - α = dot_product_attention_weights(q, k; dropout, bias, mask) + α = dot_product_attention_weights_tullio(q, k; dropout, bias, mask) # [α] = [kv_len, q_len, num_heads, batch_size] @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b] # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size] return x, α end -function dot_product_attention_weights(q::A4{T}, k::A4{T}; +function dot_product_attention_weights_tullio(q::A4{T}, k::A4{T}; dropout=nothing, mask=nothing, bias=nothing) where T q = q ./ √T(size(q, 1)) @@ -162,6 +170,49 @@ function dot_product_attention_weights(q::A4{T}, k::A4{T}; return dropout === nothing ? 
α : dropout(α) end +function NNlib.batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N} + sz = size(x)[3:end] + @assert sz == size(y)[3:end] + x2 = reshape(x, size(x, 1), size(x, 2), :) + y2 = reshape(y, size(y, 1), size(y, 2), :) + z = NNlib.batched_mul(x2, y2) + return reshape(z, size(z, 1), size(z, 2), sz...) +end + +function dot_product_attention_native(q::A4, k::A4, v::A4; + dropout=nothing, bias=nothing, mask=nothing) + + α = dot_product_attention_weights_native(q, k; dropout, bias, mask) + # [α] = [kv_len, q_len, num_heads, batch_size] + + vt = permutedims(v, (1, 3, 2, 4)) + x = NNlib.batched_mul(vt, α) + x = permutedims(x, (1, 3, 2, 4)) + # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size] + return x, α +end + +function dot_product_attention_weights_native(q::A4{T}, k::A4{T}; + dropout=nothing, mask=nothing, bias=nothing) where T + + q = q ./ √T(size(q, 1)) + kt = permutedims(k, (3, 1, 2, 4)) + qt = permutedims(q, (1, 3, 2, 4)) + α = NNlib.batched_mul(kt, qt) + # [α] = [kv_len, q_len, num_heads, batch_size] + + if bias !== nothing + α = α .+ bias + end + if mask !== nothing + neginf = typemin(eltype(α)) + α = ifelse.(mask, α, neginf) + end + + α = softmax(α, dims=1) + return dropout === nothing ? α : dropout(α) +end + struct QKVProj q_proj::Dense @@ -206,6 +257,10 @@ function perf(dim, len, batch_size, num_heads) println("nalib") @btime $mha($x, $x, $x, impl=:nalib); @btime gradient(m -> sum(m($x, impl=:nalib)), $mha); + + println("native") + @btime $mha($x, $x, $x, impl=:native); + @btime gradient(m -> sum(m($x, impl=:native)), $mha); if CUDA.functional() mha_gpu = mha |> gpu @@ -218,6 +273,10 @@ function perf(dim, len, batch_size, num_heads) println("nalib - gpu") @btime CUDA.@sync $mha_gpu($x_gpu, impl=:nalib); @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:nalib)), $mha_gpu); + + println("native - gpu") + @btime CUDA.@sync $mha_gpu($x_gpu, impl=:native); + @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:native)), $mha_gpu); end return nothing end @@ -240,6 +299,12 @@ function test(dim, num_heads, len, batch_size) @test size(α) == size(α2) @test α2 ≈ α + y2b, α2b = mha(q, k, v, impl=:native, with_weights=true) + @test size(y) == size(y2b) + @test y2b ≈ y + @test size(α) == size(α2b) + @test α2b ≈ α + mask = make_causal_mask(q) y3, α3 = mha(q, k, v; impl=:tullio, with_weights=true, mask) y4, α4 = mha(q, k, v, impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask()) @@ -273,16 +338,22 @@ perf(128, 8, 128, 32) # nalib - 6 threads # 7.832 ms (187 allocations: 7.76 MiB) # 29.823 ms (988 allocations: 16.19 MiB) +# native +# 6.269 ms (90 allocations: 9.25 MiB) +# 15.492 ms (1250 allocations: 22.19 MiB) # tullio - gpu # 147.746 μs (523 allocations: 24.59 KiB) # 957.111 μs (2413 allocations: 127.88 KiB) # nalib - gpu # 165.109 μs (411 allocations: 18.05 KiB) # 659.685 μs (1527 allocations: 86.09 KiB) - -dim = 2; len = 3; batch_size = 1; num_heads = 1 -mha = MultiHeadAttention(dim, num_heads) -x = rand(Float32, (dim, len, batch_size)) -mask = make_causal_mask(x) -y, α = mha(x; impl=:tullio, with_weights=true, mask) -y2, α2 = mha(x; impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask()) +# native - gpu +# 158.396 μs (443 allocations: 20.06 KiB) +# 920.633 μs (2308 allocations: 118.78 KiB) + +# dim = 2; len = 3; batch_size = 1; num_heads = 1 +# mha = MultiHeadAttention(dim, num_heads) +# x = rand(Float32, (dim, len, batch_size)) +# mask = make_causal_mask(x) +# y, α = mha(x; impl=:tullio, with_weights=true, mask) +# y2, α2 = 
mha(x; impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask()) diff --git a/test_jax.py b/test_jax.py index 9db5c91f65..5957f8896f 100644 --- a/test_jax.py +++ b/test_jax.py @@ -9,6 +9,7 @@ #import tensorflow_datasets as tfds # TFDS for MNIST # %% x = jnp.arange(16).reshape(1,2,2,4) / 16 +alpha = nn.dot_product_attention_weights(x, x) y = nn.dot_product_attention(x, x, x) yt = y.transpose((3,2,1,0)) From 30b22d723ed713517a44062d85d9b5089b57e650 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Mon, 2 Jan 2023 10:56:44 +0100 Subject: [PATCH 10/21] support mask = :causal --- src/layers/attention.jl | 62 ++++++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 26 deletions(-) diff --git a/src/layers/attention.jl b/src/layers/attention.jl index 11fbdd312c..cb2f1f63bd 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -44,11 +44,11 @@ Multi-head dot-product attention layer. mha = MultiHeadAttention(64, 8) ``` """ -struct MultiHeadAttention +struct MultiHeadAttention{P1, D, P2} num_heads::Int - qkv_proj - attn_drop - out_proj + qkv_proj::P1 + attn_drop::D + out_pro::P2 end @functor MultiHeadAttention @@ -56,14 +56,13 @@ end function MultiHeadAttention(dims, num_heads::Int; bias::Bool = false, init = glorot_uniform, - attn_dropout_prob = 0.0, - out_proj_dropout_prob = 0.0) + attn_dropout_prob = 0.0) dims = mha_process_dims(dims) @assert dims.qk % num_heads == 0 "qk_dim should be divisible by num_heads" qkv_proj = QKVProj(dims; bias, init) attn_drop = Dropout(attn_dropout_prob) - out_proj = Chain(Dense(dims.v => dims.out; bias, init), Dropout(out_proj_dropout_prob)) + out_proj = Dense(dims.v => dims.out; bias, init) return MultiHeadAttention(num_heads, qkv_proj, attn_drop, out_proj) end @@ -100,7 +99,7 @@ end (m::MultiHeadAttention)(q, kv; kws...) = m(q, kv, kv; kws...) function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; - with_weights=false, mask=nothing, impl=:tullio) + with_weights=false, mask=nothing, impl=:native) ## [q_in] = [q_in_dim, q_len, batch_size] ## [k_in] = [k_in_dim, kv_len, batch_size] ## [v_in] = [v_in_dim, kv_len, batch_size] @@ -115,7 +114,7 @@ function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; elseif impl == :nalib x, α = NeuralAttentionlib.multihead_qkv_attention(score_returning, m.num_heads, q, k, v, mask) elseif impl == :native - x, α = dot_product_attention_native(m.num_heads, q, k, v; mask, dropout=m.attn_drop) + x, α = dot_product_attention(m.num_heads, q, k, v; mask, dropout=m.attn_drop) else error("Unknown attention implementation") end @@ -134,11 +133,6 @@ function dot_product_attention_tullio(num_heads::Int, q::A3, k::A3, v::A3; kws.. return flatten_heads(x), α end -function dot_product_attention_native(num_heads::Int, q::A3, k::A3, v::A3; kws...) - q, k, v = reshape_heads.((q, k, v), num_heads) - x, α = dot_product_attention_native(q, k, v; kws...) - return flatten_heads(x), α -end # Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html function dot_product_attention_tullio(q::A4, k::A4, v::A4; @@ -179,12 +173,20 @@ function NNlib.batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where return reshape(z, size(z, 1), size(z, 2), sz...) end -function dot_product_attention_native(q::A4, k::A4, v::A4; +function dot_product_attention(num_heads::Int, q::A3, k::A3, v::A3; kws...) + q, k, v = reshape_heads.((q, k, v), num_heads) + x, α = dot_product_attention(q, k, v; kws...) 
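+    # [x] = [v_dim ÷ num_heads, num_heads, q_len, batch_size]; flatten_heads merges the
+    # per-head slices back into a single feature dimension of size v_dim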
+ return flatten_heads(x), α +end + +function dot_product_attention(q::A4, k::A4, v::A4; dropout=nothing, bias=nothing, mask=nothing) - α = dot_product_attention_weights_native(q, k; dropout, bias, mask) + α = dot_product_attention_weights(q, k; dropout, bias, mask) # [α] = [kv_len, q_len, num_heads, batch_size] + # The following permutations and batched_mul are equivalent to + # @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b] vt = permutedims(v, (1, 3, 2, 4)) x = NNlib.batched_mul(vt, α) x = permutedims(x, (1, 3, 2, 4)) @@ -192,10 +194,13 @@ function dot_product_attention_native(q::A4, k::A4, v::A4; return x, α end -function dot_product_attention_weights_native(q::A4{T}, k::A4{T}; +function dot_product_attention_weights(q::A4{T}, k::A4{T}; dropout=nothing, mask=nothing, bias=nothing) where T q = q ./ √T(size(q, 1)) + + # The following permutations and batched_mul are equivalent to + # @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] kt = permutedims(k, (3, 1, 2, 4)) qt = permutedims(q, (1, 3, 2, 4)) α = NNlib.batched_mul(kt, qt) @@ -204,7 +209,11 @@ function dot_product_attention_weights_native(q::A4{T}, k::A4{T}; if bias !== nothing α = α .+ bias end + if mask !== nothing + if mask === :causal + mask = make_causal_mask(α) + end neginf = typemin(eltype(α)) α = ifelse.(mask, α, neginf) end @@ -329,15 +338,9 @@ perf(128, 8, 128, 32) # tullio # 5.475 ms (80 allocations: 7.25 MiB) # 13.073 ms (1172 allocations: 18.18 MiB) -# tullio - 6 threads -# 4.818 ms (192 allocations: 7.26 MiB) -# 10.927 ms (1398 allocations: 18.19 MiB) # nalib # 6.040 ms (91 allocations: 7.75 MiB) # 14.542 ms (696 allocations: 16.17 MiB) -# nalib - 6 threads -# 7.832 ms (187 allocations: 7.76 MiB) -# 29.823 ms (988 allocations: 16.19 MiB) # native # 6.269 ms (90 allocations: 9.25 MiB) # 15.492 ms (1250 allocations: 22.19 MiB) @@ -351,9 +354,16 @@ perf(128, 8, 128, 32) # 158.396 μs (443 allocations: 20.06 KiB) # 920.633 μs (2308 allocations: 118.78 KiB) -# dim = 2; len = 3; batch_size = 1; num_heads = 1 +# perf(384, 12, 256, 32) + + +# dim, len, batch_size, num_heads = 128, 8, 128, 32; +# # dim = 384; len = 128; batch_size = 32; num_heads = 12 # mha = MultiHeadAttention(dim, num_heads) # x = rand(Float32, (dim, len, batch_size)) -# mask = make_causal_mask(x) -# y, α = mha(x; impl=:tullio, with_weights=true, mask) +# @btime mha(x, impl=:tullio); +# @btime mha(x, impl=:native); +# @profview mha(x, impl=:tullio); +# @profview [mha(x, impl=:native) for _ in 1:100]; +# y, α = mha(x; impl=:native, with_weights=true, mask) # y2, α2 = mha(x; impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask()) From 38b8bdf7decba3f28a1bc4359374b005a97c493f Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 5 Jan 2023 08:54:29 +0100 Subject: [PATCH 11/21] [ci skip] factor out impl --- src/layers/attention.jl | 274 ++++++++++----------------------- src/layers/attention_nnlib.jl | 151 ++++++++++++++++++ src/layers/attention_tullio.jl | 41 +++++ 3 files changed, 271 insertions(+), 195 deletions(-) create mode 100644 src/layers/attention_nnlib.jl create mode 100644 src/layers/attention_tullio.jl diff --git a/src/layers/attention.jl b/src/layers/attention.jl index cb2f1f63bd..4e226d33bd 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -1,42 +1,43 @@ using Flux, Functors, Test, LinearAlgebra, Random, Statistics using CUDA -using CUDAKernels, KernelAbstractions, LoopVectorization, Tullio using NeuralAttentionlib using NeuralAttentionlib: score_returning using BenchmarkTools using Flux: 
glorot_uniform -using MLUtils -using ChainRulesCore CUDA.allowscalar(false) const A3{T} = AbstractArray{T, 3} const A4{T} = AbstractArray{T, 4} +const TuplInt2 = Union{Int, Tuple{Int, Int}} +const TuplInt3 = Union{Int, Tuple{Int, Int, Int}} + +include("attention_nnlib.jl") +include("attention_tullio.jl") + """ - MultiHeadAttention(dims, num_heads; - [bias, init, attn_dropout_prob, proj_dropout_prob]) + MultiHeadAttention(dims, nheads; [bias, init, dropout_prob]) Multi-head dot-product attention layer. # Arguments - `dims`: ... -- `num_heads`: number of heads. +- `nheads`: number of heads. - `init`: weight initializer for the Dense layers. - `bias` : whether pointwise QKVO dense transforms use bias. -- `attn_dropout_prob`: dropout probability after the self-attention layer -- `proj_dropout_prob`: dropout probability after the projection layer +- `dropout_prob`: dropout probability for the attention scores. # Forward - (::MultiHeadAttention)(q_in, k_in, v_in; [mask, with_weights]) + (::MultiHeadAttention)(q_in, k_in, v_in, [bias]; [mask, withscores]) -- `q_in`: input array of size `( seq_len, dims) -- `k_in`: input array of size `( seq_len, dims) -- `v_in`: input array of size `( seq_len, dims) +- `q_in`: input query array of size `(q_in_dim, q_len, batch_size...)`. +- `k_in`: input key array of size `(k_in_dim, kv_len, batch_size...)`. +- `v_in`: input value array of size `(v_in_dim, kv_len, batch_size...)`. - `mask`: input array broadcastable to size - `(kv_len, q_len, num_heads, batch_size)`. Default `nothing`. -- `with_weights`: Whether to return the attention weights. Default `false`. + `(kv_len, q_len, nheads, batch_size)`. Default `nothing`. +- `withscores`: Whether to return the attention scores. Default `false`. # Examples @@ -45,39 +46,30 @@ mha = MultiHeadAttention(64, 8) ``` """ struct MultiHeadAttention{P1, D, P2} - num_heads::Int + nheads::Int qkv_proj::P1 attn_drop::D - out_pro::P2 + out_proj::P2 end @functor MultiHeadAttention -function MultiHeadAttention(dims, num_heads::Int; +function MultiHeadAttention(dims, nheads::Int; bias::Bool = false, init = glorot_uniform, - attn_dropout_prob = 0.0) + dropout_prob = 0.0) dims = mha_process_dims(dims) - @assert dims.qk % num_heads == 0 "qk_dim should be divisible by num_heads" + @assert dims.qk % nheads == 0 "qk_dim should be divisible by nheads" qkv_proj = QKVProj(dims; bias, init) - attn_drop = Dropout(attn_dropout_prob) + attn_drop = Dropout(dropout_prob) out_proj = Dense(dims.v => dims.out; bias, init) - return MultiHeadAttention(num_heads, qkv_proj, attn_drop, out_proj) + return MultiHeadAttention(nheads, qkv_proj, attn_drop, out_proj) end -# The following inputs are equivalent: -# 8 -# 8 => 8 => 8 -# (8, 8, 8) => 8 => 8 -# 8 => (8, 8) => 8 -# (8, 8, 8) => (8, 8) => 8 # (q_in, k_in, v_in) => (qk, v) => out mha_process_dims(dims::Int) = (; q_in=dims, k_in=dims, v_in=dims, qk=dims, v=dims, out=dims) -const TuplInt2 = Union{Int, Tuple{Int, Int}} -const TuplInt3 = Union{Int, Tuple{Int, Int, Int}} - function mha_process_dims((in, (qkv, out))::Pair{<:TuplInt3, <:Pair{<:TuplInt2, Int}}) if in isa Int q_in = k_in = v_in = in @@ -98,8 +90,8 @@ end # key and value are the same (m::MultiHeadAttention)(q, kv; kws...) = m(q, kv, kv; kws...) 
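+# Usage sketch (illustrative; assumes `dot_product_attention` and `make_causal_mask`
+# come from the included attention_nnlib.jl):
+#   mha = MultiHeadAttention(64, 8)
+#   x = rand(Float32, 64, 10, 32)            # (dims, seq_len, batch_size)
+#   y = mha(x)                               # self-attention, size (64, 10, 32)
+#   y, α = mha(x; withscores = true)         # α expected as (kv_len, q_len, nheads, batch_size)
+#   y = mha(x; mask = make_causal_mask(x))   # each position attends to itself and earlier positions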
-function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; - with_weights=false, mask=nothing, impl=:native) +function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3, bias=nothing; + withscores=false, mask=nothing, impl=:nnlib) ## [q_in] = [q_in_dim, q_len, batch_size] ## [k_in] = [k_in_dim, kv_len, batch_size] ## [v_in] = [v_in_dim, kv_len, batch_size] @@ -110,119 +102,20 @@ function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3; # [v] = [v_dim, kv_len, batch_size] if impl == :tullio - x, α = dot_product_attention_tullio(m.num_heads, q, k, v; mask, dropout=m.attn_drop) + x, α = dot_product_attention_tullio(m.nheads, q, k, v; mask, dropout=m.attn_drop) elseif impl == :nalib - x, α = NeuralAttentionlib.multihead_qkv_attention(score_returning, m.num_heads, q, k, v, mask) - elseif impl == :native - x, α = dot_product_attention(m.num_heads, q, k, v; mask, dropout=m.attn_drop) + x, α = NeuralAttentionlib.multihead_qkv_attention(score_returning, m.nheads, q, k, v, mask) + elseif impl == :nnlib + x, α = dot_product_attention(q, k, v, bias; m.nheads, mask, fdrop=m.attn_drop) else error("Unknown attention implementation") end x = m.out_proj(x) - return with_weights ? (x, α) : x -end - -reshape_heads(x, num_heads) = reshape(x, size(x, 1) ÷ num_heads, num_heads, size(x)[2:end]...) -flatten_heads(x) = reshape(x, :, size(x)[3:end]...) - -function dot_product_attention_tullio(num_heads::Int, q::A3, k::A3, v::A3; kws...) - q, k, v = reshape_heads.((q, k, v), num_heads) - x, α = dot_product_attention_tullio(q, k, v; kws...) - return flatten_heads(x), α -end - - -# Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html -function dot_product_attention_tullio(q::A4, k::A4, v::A4; - dropout=nothing, bias=nothing, mask=nothing) - - α = dot_product_attention_weights_tullio(q, k; dropout, bias, mask) - # [α] = [kv_len, q_len, num_heads, batch_size] - @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b] - # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size] - return x, α -end - -function dot_product_attention_weights_tullio(q::A4{T}, k::A4{T}; - dropout=nothing, mask=nothing, bias=nothing) where T - - q = q ./ √T(size(q, 1)) - @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] - # [α] = [kv_len, q_len, num_heads, batch_size] - - if bias !== nothing - α = α .+ bias - end - if mask !== nothing - neginf = typemin(eltype(α)) - α = ifelse.(mask, α, neginf) - end - - α = softmax(α, dims=1) - return dropout === nothing ? α : dropout(α) -end - -function NNlib.batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N} - sz = size(x)[3:end] - @assert sz == size(y)[3:end] - x2 = reshape(x, size(x, 1), size(x, 2), :) - y2 = reshape(y, size(y, 1), size(y, 2), :) - z = NNlib.batched_mul(x2, y2) - return reshape(z, size(z, 1), size(z, 2), sz...) + return withscores ? (x, α) : x end -function dot_product_attention(num_heads::Int, q::A3, k::A3, v::A3; kws...) - q, k, v = reshape_heads.((q, k, v), num_heads) - x, α = dot_product_attention(q, k, v; kws...) 
- return flatten_heads(x), α -end - -function dot_product_attention(q::A4, k::A4, v::A4; - dropout=nothing, bias=nothing, mask=nothing) - - α = dot_product_attention_weights(q, k; dropout, bias, mask) - # [α] = [kv_len, q_len, num_heads, batch_size] - - # The following permutations and batched_mul are equivalent to - # @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b] - vt = permutedims(v, (1, 3, 2, 4)) - x = NNlib.batched_mul(vt, α) - x = permutedims(x, (1, 3, 2, 4)) - # [x] = [kv_dim ÷ num_heads, num_heads, q_len, batch_size] - return x, α -end - -function dot_product_attention_weights(q::A4{T}, k::A4{T}; - dropout=nothing, mask=nothing, bias=nothing) where T - - q = q ./ √T(size(q, 1)) - - # The following permutations and batched_mul are equivalent to - # @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] - kt = permutedims(k, (3, 1, 2, 4)) - qt = permutedims(q, (1, 3, 2, 4)) - α = NNlib.batched_mul(kt, qt) - # [α] = [kv_len, q_len, num_heads, batch_size] - - if bias !== nothing - α = α .+ bias - end - - if mask !== nothing - if mask === :causal - mask = make_causal_mask(α) - end - neginf = typemin(eltype(α)) - α = ifelse.(mask, α, neginf) - end - - α = softmax(α, dims=1) - return dropout === nothing ? α : dropout(α) -end - - struct QKVProj q_proj::Dense k_proj::Dense @@ -242,21 +135,8 @@ function (proj::QKVProj)(q_in, k_in, v_in) return (proj.q_proj(q_in), proj.k_proj(k_in), proj.v_proj(v_in)) end -function make_causal_mask(x::A3) - d, len, batch_size = size(x) - mask = triu(trues_like(x, (len, len))) - return mask -end - -trues_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), true) -falses_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), false) - -@non_differentiable make_causal_mask(x) -@non_differentiable trues_like(::Any...) -@non_differentiable falses_like(::Any...) 
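
Editor's note: as a usage sketch for the mask path just shown (illustrative, not part of the diff), `make_causal_mask` builds a Boolean matrix whose masked entries are sent to `typemin` before the softmax, so future positions receive zero weight. This mirrors the causal-mask test used later in the PR:

```julia
using Flux, LinearAlgebra

dim, len, batch_size = 8, 5, 2
mha = MultiHeadAttention(dim, 2)
x = rand(Float32, dim, len, batch_size)

mask = make_causal_mask(x)            # len×len Bool matrix with mask[j, i] == (j <= i)
y, α = mha(x; mask, withscores=true)
# query i only attends to keys j <= i, so each score slice is upper triangular:
α[:, :, 1, 1] ≈ triu(α[:, :, 1, 1])
```
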
- -function perf(dim, len, batch_size, num_heads) - mha = MultiHeadAttention(dim, num_heads) +function perf(dim, len, batch_size, nheads) + mha = MultiHeadAttention(dim, nheads) x = rand(Float32, (dim, len, batch_size)) println("tullio") @@ -267,9 +147,9 @@ function perf(dim, len, batch_size, num_heads) @btime $mha($x, $x, $x, impl=:nalib); @btime gradient(m -> sum(m($x, impl=:nalib)), $mha); - println("native") - @btime $mha($x, $x, $x, impl=:native); - @btime gradient(m -> sum(m($x, impl=:native)), $mha); + println("nnlib") + @btime $mha($x, $x, $x, impl=:nnlib); + @btime gradient(m -> sum(m($x, impl=:nnlib)), $mha); if CUDA.functional() mha_gpu = mha |> gpu @@ -283,40 +163,40 @@ function perf(dim, len, batch_size, num_heads) @btime CUDA.@sync $mha_gpu($x_gpu, impl=:nalib); @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:nalib)), $mha_gpu); - println("native - gpu") - @btime CUDA.@sync $mha_gpu($x_gpu, impl=:native); - @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:native)), $mha_gpu); + println("nnlib - gpu") + @btime CUDA.@sync $mha_gpu($x_gpu, impl=:nnlib); + @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:nnlib)), $mha_gpu); end return nothing end -function test(dim, num_heads, len, batch_size) - mha = MultiHeadAttention(dim, num_heads) +function test(dim, nheads, len, batch_size) + mha = MultiHeadAttention(dim, nheads) q = rand(Float32, (dim, len, batch_size)) k = rand(Float32, (dim, len, batch_size)) v = rand(Float32, (dim, len, batch_size)) - y, α = mha(q, k, v, impl=:tullio, with_weights=true) + y, α = mha(q, k, v, impl=:tullio, withscores=true) @test y isa Array{Float32, 3} @test size(y) == (dim, len, batch_size) @test α isa Array{Float32, 4} - @test size(α) == (len, len, num_heads, batch_size) + @test size(α) == (len, len, nheads, batch_size) - y2, α2 = mha(q, k, v, impl=:nalib, with_weights=true) + y2, α2 = mha(q, k, v, impl=:nalib, withscores=true) @test size(y) == size(y2) @test y2 ≈ y @test size(α) == size(α2) @test α2 ≈ α - y2b, α2b = mha(q, k, v, impl=:native, with_weights=true) + y2b, α2b = mha(q, k, v, impl=:nnlib, withscores=true) @test size(y) == size(y2b) @test y2b ≈ y @test size(α) == size(α2b) @test α2b ≈ α mask = make_causal_mask(q) - y3, α3 = mha(q, k, v; impl=:tullio, with_weights=true, mask) - y4, α4 = mha(q, k, v, impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask()) + y3, α3 = mha(q, k, v; impl=:tullio, withscores=true, mask) + y4, α4 = mha(q, k, v, impl=:nalib, withscores=true, mask=NeuralAttentionlib.CausalMask()) @test y3 ≈ y4 @test α3 ≈ α4 @@ -335,35 +215,39 @@ end test(4, 2, 3, 1) perf(128, 8, 128, 32) + +## M1 Pro, NNlib v0.8.12 +# tullio +# 2.948 ms (77 allocations: 7.25 MiB) +# 15.041 ms (1124 allocations: 16.71 MiB) +# nalib +# 3.503 ms (89 allocations: 7.75 MiB) +# 15.828 ms (604 allocations: 14.70 MiB) +# nnlib +# 3.611 ms (87 allocations: 9.25 MiB) +# 16.497 ms (1055 allocations: 20.71 MiB) + +## M1 Pro, NNlib v0.8.13 (fast_maximum) # tullio -# 5.475 ms (80 allocations: 7.25 MiB) -# 13.073 ms (1172 allocations: 18.18 MiB) +# 2.427 ms (71 allocations: 7.13 MiB) +# 14.510 ms (1118 allocations: 16.59 MiB) # nalib -# 6.040 ms (91 allocations: 7.75 MiB) -# 14.542 ms (696 allocations: 16.17 MiB) -# native -# 6.269 ms (90 allocations: 9.25 MiB) -# 15.492 ms (1250 allocations: 22.19 MiB) -# tullio - gpu -# 147.746 μs (523 allocations: 24.59 KiB) -# 957.111 μs (2413 allocations: 127.88 KiB) -# nalib - gpu -# 165.109 μs (411 allocations: 18.05 KiB) -# 659.685 μs (1527 allocations: 86.09 KiB) -# native - gpu -# 158.396 μs (443 
allocations: 20.06 KiB) -# 920.633 μs (2308 allocations: 118.78 KiB) - -# perf(384, 12, 256, 32) - - -# dim, len, batch_size, num_heads = 128, 8, 128, 32; -# # dim = 384; len = 128; batch_size = 32; num_heads = 12 -# mha = MultiHeadAttention(dim, num_heads) -# x = rand(Float32, (dim, len, batch_size)) -# @btime mha(x, impl=:tullio); -# @btime mha(x, impl=:native); -# @profview mha(x, impl=:tullio); -# @profview [mha(x, impl=:native) for _ in 1:100]; -# y, α = mha(x; impl=:native, with_weights=true, mask) -# y2, α2 = mha(x; impl=:nalib, with_weights=true, mask=NeuralAttentionlib.CausalMask()) +# 3.052 ms (84 allocations: 7.63 MiB) +# 15.327 ms (599 allocations: 14.57 MiB) +# nnlib +# 3.166 ms (81 allocations: 9.13 MiB) +# 16.082 ms (1049 allocations: 20.58 MiB) + + +# function prof() + # dim, len, batch_size, nheads = 128, 8, 128, 32; + # # dim = 384; len = 128; batch_size = 32; nheads = 12 + # mha = MultiHeadAttention(dim, nheads) + # x = rand(Float32, (dim, len, batch_size)) + # @btime mha(x, impl=:tullio); + # @btime mha(x, impl=:nnlib); + # @profview mha(x, impl=:tullio); + # @profview prof(mha, x); + # y, α = mha(x; impl=:nnlib, withscores=true, mask) + # y2, α2 = mha(x; impl=:nalib, withscores=true, mask=NeuralAttentionlib.CausalMask()) +# end \ No newline at end of file diff --git a/src/layers/attention_nnlib.jl b/src/layers/attention_nnlib.jl new file mode 100644 index 0000000000..3c95439b03 --- /dev/null +++ b/src/layers/attention_nnlib.jl @@ -0,0 +1,151 @@ +using ChainRulesCore +using NNlib + +const AA3{T} = AbstractArray{T,3} +const AA4{T} = AbstractArray{T,4} +const AA{N,T} = AbstractArray{T,N} + +""" + dot_product_attention(query, key, value, [bias]; [fdrop, mask, nheads]) + +Multihead dot product attention used in transformer architectures. + +The input arrays must have the first two dimensions given by the number of features +and the sequece length, then an arbitrary number of batch dimensions or none. + +Returns the attention output array of size `(v_dim, q_len, batch_size...)` and the attention scores. +of size `(kv_len, q_len, nheads, batch_size...)`. + +See also [`dot_product_attention_scores`](@ref) if you only need the attention scores. + +# Arguments + +- `query`: Query array of size `(qk_dim, q_len, batch_size...)`. +- `key`: Key array of size `(qk_dim, kv_len, batch_size...)`. +- `value`: Value array of size `(v_dim, kv_len, batch_size...)`. +- `bias`: Either `nothing` or an array broadcastable to size `(kv_len, q_len, nheads, batch_size)`. + It will be added to the attention scores before applying the softmax. Default `nothing`. +- `fdrop`: A dropout function or layer to be applied on the attention scores right after the softmax. + Default `identity` (no dropout). +- `mask`: Either `nothing` or a boolean array broadcastable to size `(kv_len, q_len, nheads, batch_size)`. + The mask is applied to the attention scores before the softmax. + Can also be set to `mask=:causal` to apply a causal mask. Default `nothing`. +- `nheads`: Number of heads to split the input arrays into. Default `1`. + +# Examples + +```julia +q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2) +y, α = dot_product_attention(q, k, v) +``` +""" +function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}, args...; kws...) 
where N + batch_size = size(q)[3:end] + batch_size == size(k)[3:end] == size(v)[3:end] || throw(ArgumentError("Batch dimensions have to be the same.")) + q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v)) + + x, α = dot_product_attention(q, k, v, args...; kws...) + + x = reshape(x, size(x, 1), size(x, 2), batch_size...) + α = reshape(α, size(α)[1:3]..., batch_size...) + return x, α +end + +function dot_product_attention(q::AA3, k::AA3, v::AA3, bias=nothing; + fdrop=identity, mask=nothing, nheads=1) + + (size(q, 3) == size(k, 3) == size(v, 3)) || throw(ArgumentError("Batch dimensions have to be the same.")) + size(q, 1) == size(k, 1) || throw(ArgumentError("First dimension in query and key has to be the same.")) + size(k, 2) == size(v, 2) || throw(ArgumentError("Second dimension in key and value has to be the same.")) + + # Multihead attention. TODO create fastpath for singlehead attention. + q, k, v = split_heads.((q, k, v), nheads) + x, α = _dot_product_attention(q, k, v, bias, fdrop, mask) + return join_heads(x), α +end + +function _dot_product_attention(q::AA4, k::AA4, v::AA4, bias, fdrop, mask) + # [q] = [qk_dim ÷ nheads, nheads, q_len, batch_size] + # [k] = [qk_dim ÷ nheads, nheads, kv_len, batch_size] + # [v] = [v_dim ÷ nheads, nheads, kv_len, batch_size] + + α = dot_product_attention_scores(q, k, bias; fdrop, mask) + # [α] = [kv_len, q_len, nheads, batch_size] + + # The following permutedims and batched_mul are equivalent to + # @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b] + vt = permutedims(v, (1, 3, 2, 4)) + x = batched_mul(vt, α) + x = permutedims(x, (1, 3, 2, 4)) + # [x] = [kv_dim ÷ nheads, nheads, q_len, batch_size] + return x, α +end + +""" + dot_product_attention_scores(query, key, [bias]; [fdrop, mask]) + +Return the attention scores for the [`dot_product_attention`](@ref). +Input arrays must have dimensions +`(num_features ÷ nheads, nheads, sequence_length, batch_size)`. + +See [`dot_product_attention`](@ref) for more details. +""" +function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing; + fdrop=identity, mask=nothing) where T + + # The following permutedims and batched_mul are equivalent to + # @tullio logits[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] / √T(qk_dim) + kt = permutedims(k, (3, 1, 2, 4)) + qt = permutedims(q, (1, 3, 2, 4)) ./ √T(size(q, 1)) + logits = batched_mul(kt, qt) + # [logits] = [kv_len, q_len, nheads, batch_size] + + if bias !== nothing + logits = logits .+ bias + end + + if mask !== nothing + if mask === :causal + mask = make_causal_mask(logits) + end + neginf = typemin(eltype(logits)) + logits = ifelse.(mask, logits, neginf) + end + + α = softmax(logits, dims=1) + return fdrop(α) +end + +""" + make_causal_mask(x, dims=2) + +Return a boolean square matrix `m` of the same type as `x` and of side `size(x, dims)`. +Its elements are set such that `m[i, j] == i ≤ j`. + +Can be used to mask the attention scores in [`dot_product_attention`](@ref). +""" +function make_causal_mask(x::AbstractArray; dims::Int=2) + len = size(x, dims) + mask = triu(trues_like(x, (len, len))) + return mask +end + +trues_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), true) +falses_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), false) + +split_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...) +join_heads(x) = reshape(x, :, size(x)[3:end]...) + +@non_differentiable make_causal_mask(x) +@non_differentiable trues_like(::Any...) +@non_differentiable falses_like(::Any...) 
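
Editor's note: for orientation, here is a shape sketch of the head-splitting convention defined above (illustrative only, sizes invented):

```julia
using NNlib  # softmax, batched_mul

qk_dim, v_dim, nheads = 8, 10, 2
q_len, kv_len, batch_size = 5, 7, 3

q = rand(Float32, qk_dim, q_len, batch_size)
k = rand(Float32, qk_dim, kv_len, batch_size)
v = rand(Float32, v_dim, kv_len, batch_size)

size(split_heads(q, nheads))  # (4, 2, 5, 3) == (qk_dim ÷ nheads, nheads, q_len, batch_size)

x, α = dot_product_attention(q, k, v; nheads)
size(x)  # (10, 5, 3)   == (v_dim, q_len, batch_size)
size(α)  # (7, 5, 2, 3) == (kv_len, q_len, nheads, batch_size)
```
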
+ +function NNlib.batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N} + batch_size = size(x)[3:end] + @assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays." + x2 = reshape(x, size(x, 1), size(x, 2), :) + y2 = reshape(y, size(y, 1), size(y, 2), :) + z = batched_mul(x2, y2) + return reshape(z, size(z, 1), size(z, 2), batch_size...) + end + diff --git a/src/layers/attention_tullio.jl b/src/layers/attention_tullio.jl new file mode 100644 index 0000000000..6d8a2ad6ec --- /dev/null +++ b/src/layers/attention_tullio.jl @@ -0,0 +1,41 @@ +using CUDAKernels, KernelAbstractions, LoopVectorization, Tullio + +reshape_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...) +flatten_heads(x) = reshape(x, :, size(x)[3:end]...) + +function dot_product_attention_tullio(nheads::Int, q::A3, k::A3, v::A3; kws...) + q, k, v = reshape_heads.((q, k, v), nheads) + x, α = dot_product_attention_tullio(q, k, v; kws...) + return flatten_heads(x), α +end + + +# Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html +function dot_product_attention_tullio(q::A4, k::A4, v::A4; + dropout=nothing, bias=nothing, mask=nothing) + + α = dot_product_attention_weights_tullio(q, k; dropout, bias, mask) + # [α] = [kv_len, q_len, nheads, batch_size] + @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b] + # [x] = [kv_dim ÷ nheads, nheads, q_len, batch_size] + return x, α +end + +function dot_product_attention_weights_tullio(q::A4{T}, k::A4{T}; + dropout=nothing, mask=nothing, bias=nothing) where T + + q = q ./ √T(size(q, 1)) + @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] + # [α] = [kv_len, q_len, nheads, batch_size] + + if bias !== nothing + α = α .+ bias + end + if mask !== nothing + neginf = typemin(eltype(α)) + α = ifelse.(mask, α, neginf) + end + + α = softmax(α, dims=1) + return dropout === nothing ? 
α : dropout(α) +end From 19fe8d9352203a44c28975303e56494a63e74d2b Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Thu, 5 Jan 2023 08:55:32 +0100 Subject: [PATCH 12/21] [ci skip] remove jax --- test_jax.py | 18 ------------------ 1 file changed, 18 deletions(-) delete mode 100644 test_jax.py diff --git a/test_jax.py b/test_jax.py deleted file mode 100644 index 5957f8896f..0000000000 --- a/test_jax.py +++ /dev/null @@ -1,18 +0,0 @@ -#%% -import jax -import jax.numpy as jnp # JAX NumPy - -from flax import linen as nn # The Linen API - -#import numpy as np # Ordinary NumPy -#import optax # Optimizers -#import tensorflow_datasets as tfds # TFDS for MNIST -# %% -x = jnp.arange(16).reshape(1,2,2,4) / 16 -alpha = nn.dot_product_attention_weights(x, x) -y = nn.dot_product_attention(x, x, x) -yt = y.transpose((3,2,1,0)) - -yt -yt.shape -# %% From 2b9b219bdc8c587f6c0a4fa257798815bc44feb9 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 5 Jan 2023 09:11:25 +0100 Subject: [PATCH 13/21] [ci skip] more benchs --- src/layers/attention.jl | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/src/layers/attention.jl b/src/layers/attention.jl index 4e226d33bd..cb68928636 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -238,6 +238,46 @@ perf(128, 8, 128, 32) # 3.166 ms (81 allocations: 9.13 MiB) # 16.082 ms (1049 allocations: 20.58 MiB) +## Threadripper, NNlib v0.8.12 +# tullio +# 5.658 ms (77 allocations: 7.25 MiB) +# 22.373 ms (1124 allocations: 16.71 MiB) +# nalib +# 6.187 ms (89 allocations: 7.75 MiB) +# 23.723 ms (604 allocations: 14.70 MiB) +# nnlib +# 6.473 ms (87 allocations: 9.25 MiB) +# 24.966 ms (1055 allocations: 20.71 MiB) +# tullio - gpu +# 145.332 μs (520 allocations: 24.52 KiB) +# 902.020 μs (2221 allocations: 117.19 KiB) +# nalib - gpu +# 162.354 μs (410 allocations: 18.03 KiB) +# 604.111 μs (1263 allocations: 71.78 KiB) +# nnlib - gpu +# 156.383 μs (440 allocations: 20.00 KiB) +# 835.374 μs (1969 allocations: 100.58 KiB) + +## Threadripper, NNlib v0.8.13 (fast_maximum) +# tullio +# 4.599 ms (71 allocations: 7.13 MiB) +# 20.699 ms (1118 allocations: 16.59 MiB) +# nalib +# 5.049 ms (84 allocations: 7.63 MiB) +# 22.252 ms (599 allocations: 14.57 MiB) +# nnlib +# 5.378 ms (81 allocations: 9.13 MiB) +# 23.453 ms (1049 allocations: 20.58 MiB) +# tullio - gpu +# 145.824 μs (520 allocations: 24.52 KiB) +# 915.305 μs (2221 allocations: 117.19 KiB) +# nalib - gpu +# 164.789 μs (410 allocations: 18.03 KiB) +# 610.835 μs (1263 allocations: 71.78 KiB) +# nnlib - gpu +# 157.785 μs (440 allocations: 20.00 KiB) +# 852.087 μs (1969 allocations: 100.58 KiB) + # function prof() # dim, len, batch_size, nheads = 128, 8, 128, 32; From 5745555deb60a603c8ec10e1b7115ad5c71fcf0d Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Sun, 5 Mar 2023 12:34:17 +0100 Subject: [PATCH 14/21] finish up --- Project.toml | 5 - src/Flux.jl | 5 +- src/layers/attention.jl | 281 +++++++-------------------------- src/layers/attention_nnlib.jl | 151 ------------------ src/layers/attention_tullio.jl | 41 ----- test.jl | 37 +++++ test/layers/attention.jl | 63 ++++++++ test/test_utils.jl | 14 ++ 8 files changed, 173 insertions(+), 424 deletions(-) delete mode 100644 src/layers/attention_nnlib.jl delete mode 100644 src/layers/attention_tullio.jl create mode 100644 test.jl create mode 100644 test/layers/attention.jl diff --git a/Project.toml b/Project.toml index 4941e1b01a..c0a027a551 100644 --- a/Project.toml +++ b/Project.toml @@ -5,17 +5,13 @@ version = "0.13.14" 
[deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" -CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" -KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d" -NeuralAttentionlib = "12afc1b8-fad6-47e1-9132-84abc478905f" OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Preferences = "21216c6a-2e73-6563-6e65-726566657250" @@ -26,7 +22,6 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] diff --git a/src/Flux.jl b/src/Flux.jl index a423311204..119e40a9ac 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -23,7 +23,9 @@ export Chain, Dense, Embedding, Maxout, SkipConnection, Parallel, PairwiseFusion RNN, LSTM, GRU, GRUv3, SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv, AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool, - Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm, + Dropout, AlphaDropout, + LayerNorm, BatchNorm, InstanceNorm, GroupNorm, + MultiHeadAttention, Upsample, PixelShuffle, fmap, cpu, gpu, f32, f64, f16, rand32, randn32, zeros32, ones32, testmode!, trainmode! @@ -59,6 +61,7 @@ include("layers/conv.jl") include("layers/recurrent.jl") include("layers/normalise.jl") include("layers/upsample.jl") +include("layers/attention.jl") include("layers/show.jl") include("loading.jl") diff --git a/src/layers/attention.jl b/src/layers/attention.jl index cb68928636..a3bcdad230 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -1,36 +1,32 @@ -using Flux, Functors, Test, LinearAlgebra, Random, Statistics -using CUDA -using NeuralAttentionlib -using NeuralAttentionlib: score_returning -using BenchmarkTools -using Flux: glorot_uniform -CUDA.allowscalar(false) const A3{T} = AbstractArray{T, 3} -const A4{T} = AbstractArray{T, 4} const TuplInt2 = Union{Int, Tuple{Int, Int}} const TuplInt3 = Union{Int, Tuple{Int, Int, Int}} -include("attention_nnlib.jl") -include("attention_tullio.jl") - - """ - MultiHeadAttention(dims, nheads; [bias, init, dropout_prob]) + MultiHeadAttention(dims; [nheads, bias, init, dropout_prob]) + +The multi-head dot-product attention layer used in Transformer architectures [1]. -Multi-head dot-product attention layer. +[1] Vaswani et al. "Attention is all you need." Advances in Neural Information Processing Systems. 2017. # Arguments -- `dims`: ... -- `nheads`: number of heads. -- `init`: weight initializer for the Dense layers. -- `bias` : whether pointwise QKVO dense transforms use bias. -- `dropout_prob`: dropout probability for the attention scores. +- `dims`: The embedding dimensions of inputs, intermediate tensors and outputs. + In the most general case, it is given as + `(q_in_dim, k_in_dim, v_in_dim) => (qk_dim, v_dim) => out_dim`. 
+ Can take also simpler forms as + `dims::Int`, `in_dim::Int => (qk_dim, v_dim) => out_dim`, + `in_dim::Int => qkv_dim => out_dim`. + +- `nheads`: number of heads. Default `8`. +- `init`: weight initializer for the Dense layers. Default `glorot_uniform`. +- `bias` : whether pointwise QKVO dense transforms use bias. Default `false`. +- `dropout_prob`: dropout probability for the attention scores. Default `0.0`. # Forward - (::MultiHeadAttention)(q_in, k_in, v_in, [bias]; [mask, withscores]) + (mha::MultiHeadAttention)(q_in, k_in, v_in, [bias]; [mask, withscores]) - `q_in`: input query array of size `(q_in_dim, q_len, batch_size...)`. - `k_in`: input key array of size `(k_in_dim, kv_len, batch_size...)`. @@ -39,38 +35,58 @@ Multi-head dot-product attention layer. `(kv_len, q_len, nheads, batch_size)`. Default `nothing`. - `withscores`: Whether to return the attention scores. Default `false`. +In alternative, `mha(q_in)` is equivalent to `mha(q_in, q_in, q_in)` (self-attention) +and `mha(q_in, k_in)` is equivalent to `mha(q_in, k_in, k_in)` (key and value are the same). + + +See also [`NNlib.dot_product_attention`](@ref). + # Examples ```julia -mha = MultiHeadAttention(64, 8) +mha = MultiHeadAttention(64, nheads = 8) +q = rand(Float32, (64, 10, 32)) +k = rand(Float32, (64, 20, 32)) +v = rand(Float32, (64, 20, 32)) +y = mha(q, k, v) # [y] = [64, 10, 32] + +mha = MultiHeadAttention(64 => 1024 => 1024, nheads = 8) +y = mha(q) # self-attention; [y] = [1024, 10, 32] ``` """ struct MultiHeadAttention{P1, D, P2} nheads::Int - qkv_proj::P1 + q_proj::P1 + k_proj::P1 + v_proj::P1 attn_drop::D out_proj::P2 end @functor MultiHeadAttention -function MultiHeadAttention(dims, nheads::Int; +function MultiHeadAttention(dims; + nheads::Int = 8, bias::Bool = false, init = glorot_uniform, dropout_prob = 0.0) - dims = mha_process_dims(dims) + dims = normalize_mha_dims(dims) @assert dims.qk % nheads == 0 "qk_dim should be divisible by nheads" - qkv_proj = QKVProj(dims; bias, init) + @assert dims.v % nheads == 0 "v_dim should be divisible by nheads" + q_proj = Dense(dims.q_in => dims.qk; bias, init) + k_proj = Dense(dims.k_in => dims.qk; bias, init) + v_proj = Dense(dims.v_in => dims.v; bias, init) attn_drop = Dropout(dropout_prob) out_proj = Dense(dims.v => dims.out; bias, init) - return MultiHeadAttention(nheads, qkv_proj, attn_drop, out_proj) + return MultiHeadAttention(nheads, q_proj, k_proj, v_proj, attn_drop, out_proj) end -mha_process_dims(dims::Int) = +# turns the dims argument into a named tuple +normalize_mha_dims(dims::Int) = (; q_in=dims, k_in=dims, v_in=dims, qk=dims, v=dims, out=dims) -function mha_process_dims((in, (qkv, out))::Pair{<:TuplInt3, <:Pair{<:TuplInt2, Int}}) +function normalize_mha_dims((in, (qkv, out))::Pair{<:TuplInt3, <:Pair{<:TuplInt2, Int}}) if in isa Int q_in = k_in = v_in = in else @@ -85,209 +101,22 @@ function mha_process_dims((in, (qkv, out))::Pair{<:TuplInt3, <:Pair{<:TuplInt2, end # self-attention -(m::MultiHeadAttention)(qkv; kws...) = m(qkv, qkv, qkv; kws...) +(mha::MultiHeadAttention)(qkv; kws...) = mha(qkv, qkv, qkv; kws...) # key and value are the same -(m::MultiHeadAttention)(q, kv; kws...) = m(q, kv, kv; kws...) +(mha::MultiHeadAttention)(q, kv; kws...) = mha(q, kv, kv; kws...) 
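
Editor's note: a cross-attention sketch of the two-argument convenience method above (illustrative, dimensions invented), where the same tensor serves as both key and value, as in an encoder-decoder block:

```julia
using Flux

# Decoder queries attending to encoder outputs (keys and values are the same).
mha  = MultiHeadAttention((16, 32, 32) => (64, 64) => 16; nheads = 4)
q_in = rand(Float32, 16, 10, 8)   # (q_in_dim, q_len, batch_size)
enc  = rand(Float32, 32, 20, 8)   # (k_in_dim == v_in_dim, kv_len, batch_size)

y = mha(q_in, enc)                # equivalent to mha(q_in, enc, enc)
size(y) == (16, 10, 8)            # (out_dim, q_len, batch_size)
```
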
-function (m::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3, bias=nothing; - withscores=false, mask=nothing, impl=:nnlib) +function (mha::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3, bias=nothing; + withscores=false, mask=nothing) ## [q_in] = [q_in_dim, q_len, batch_size] ## [k_in] = [k_in_dim, kv_len, batch_size] ## [v_in] = [v_in_dim, kv_len, batch_size] - - q, k, v = m.qkv_proj(q_in, k_in, v_in) - # [q] = [qk_dim, q_len, batch_size] - # [k] = [qk_dim, kv_len, batch_size] - # [v] = [v_dim, kv_len, batch_size] - - if impl == :tullio - x, α = dot_product_attention_tullio(m.nheads, q, k, v; mask, dropout=m.attn_drop) - elseif impl == :nalib - x, α = NeuralAttentionlib.multihead_qkv_attention(score_returning, m.nheads, q, k, v, mask) - elseif impl == :nnlib - x, α = dot_product_attention(q, k, v, bias; m.nheads, mask, fdrop=m.attn_drop) - else - error("Unknown attention implementation") - end - - x = m.out_proj(x) - + q = mha.q_proj(q_in) # [q] = [qk_dim, q_len, batch_size] + k = mha.k_proj(k_in) # [k] = [qk_dim, kv_len, batch_size] + v = mha.v_proj(v_in) # [v] = [v_dim, kv_len, batch_size] + x, α = NNlib.dot_product_attention(q, k, v, bias; mha.nheads, mask, fdrop=mha.attn_drop) + x = mha.out_proj(x) + # [x] = [out_dim, q_len, batch_size] + # [α] = [kv_len, q_len, nheads, batch_size] return withscores ? (x, α) : x end - -struct QKVProj - q_proj::Dense - k_proj::Dense - v_proj::Dense -end - -@functor QKVProj - -function QKVProj(dims; bias = false, init=glorot_uniform) - return QKVProj( - Dense(dims.q_in => dims.qk; bias, init), - Dense(dims.k_in => dims.qk; bias, init), - Dense(dims.v_in => dims.v; bias, init)) -end - -function (proj::QKVProj)(q_in, k_in, v_in) - return (proj.q_proj(q_in), proj.k_proj(k_in), proj.v_proj(v_in)) -end - -function perf(dim, len, batch_size, nheads) - mha = MultiHeadAttention(dim, nheads) - x = rand(Float32, (dim, len, batch_size)) - - println("tullio") - @btime $mha($x, impl=:tullio); - @btime gradient(m -> sum(m($x, impl=:tullio)), $mha); - - println("nalib") - @btime $mha($x, $x, $x, impl=:nalib); - @btime gradient(m -> sum(m($x, impl=:nalib)), $mha); - - println("nnlib") - @btime $mha($x, $x, $x, impl=:nnlib); - @btime gradient(m -> sum(m($x, impl=:nnlib)), $mha); - - if CUDA.functional() - mha_gpu = mha |> gpu - x_gpu = x |> gpu - - println("tullio - gpu") - @btime $mha_gpu($x_gpu, impl=:tullio); - @btime gradient(m -> sum(m($x_gpu, impl=:tullio)), $mha_gpu); - - println("nalib - gpu") - @btime CUDA.@sync $mha_gpu($x_gpu, impl=:nalib); - @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:nalib)), $mha_gpu); - - println("nnlib - gpu") - @btime CUDA.@sync $mha_gpu($x_gpu, impl=:nnlib); - @btime CUDA.@sync gradient(m -> sum(m($x_gpu, impl=:nnlib)), $mha_gpu); - end - return nothing -end - -function test(dim, nheads, len, batch_size) - mha = MultiHeadAttention(dim, nheads) - q = rand(Float32, (dim, len, batch_size)) - k = rand(Float32, (dim, len, batch_size)) - v = rand(Float32, (dim, len, batch_size)) - - y, α = mha(q, k, v, impl=:tullio, withscores=true) - @test y isa Array{Float32, 3} - @test size(y) == (dim, len, batch_size) - @test α isa Array{Float32, 4} - @test size(α) == (len, len, nheads, batch_size) - - y2, α2 = mha(q, k, v, impl=:nalib, withscores=true) - @test size(y) == size(y2) - @test y2 ≈ y - @test size(α) == size(α2) - @test α2 ≈ α - - y2b, α2b = mha(q, k, v, impl=:nnlib, withscores=true) - @test size(y) == size(y2b) - @test y2b ≈ y - @test size(α) == size(α2b) - @test α2b ≈ α - - mask = make_causal_mask(q) - y3, α3 = mha(q, k, v; 
impl=:tullio, withscores=true, mask) - y4, α4 = mha(q, k, v, impl=:nalib, withscores=true, mask=NeuralAttentionlib.CausalMask()) - @test y3 ≈ y4 - @test α3 ≈ α4 - - if CUDA.functional() - mha_gpu = mha |> gpu - q_gpu, k_gpu, v_gpu = q |> gpu, k |> gpu, v |> gpu - - y_gpu = mha_gpu(q_gpu, k_gpu, v_gpu, impl=:tullio) - y_gpu2 = mha_gpu(q_gpu, k_gpu, v_gpu, impl=:nalib) - @test Array(y_gpu) ≈ Array(y_gpu2) - @test Array(y_gpu) ≈ y - end - return nothing -end - -test(4, 2, 3, 1) - -perf(128, 8, 128, 32) - -## M1 Pro, NNlib v0.8.12 -# tullio -# 2.948 ms (77 allocations: 7.25 MiB) -# 15.041 ms (1124 allocations: 16.71 MiB) -# nalib -# 3.503 ms (89 allocations: 7.75 MiB) -# 15.828 ms (604 allocations: 14.70 MiB) -# nnlib -# 3.611 ms (87 allocations: 9.25 MiB) -# 16.497 ms (1055 allocations: 20.71 MiB) - -## M1 Pro, NNlib v0.8.13 (fast_maximum) -# tullio -# 2.427 ms (71 allocations: 7.13 MiB) -# 14.510 ms (1118 allocations: 16.59 MiB) -# nalib -# 3.052 ms (84 allocations: 7.63 MiB) -# 15.327 ms (599 allocations: 14.57 MiB) -# nnlib -# 3.166 ms (81 allocations: 9.13 MiB) -# 16.082 ms (1049 allocations: 20.58 MiB) - -## Threadripper, NNlib v0.8.12 -# tullio -# 5.658 ms (77 allocations: 7.25 MiB) -# 22.373 ms (1124 allocations: 16.71 MiB) -# nalib -# 6.187 ms (89 allocations: 7.75 MiB) -# 23.723 ms (604 allocations: 14.70 MiB) -# nnlib -# 6.473 ms (87 allocations: 9.25 MiB) -# 24.966 ms (1055 allocations: 20.71 MiB) -# tullio - gpu -# 145.332 μs (520 allocations: 24.52 KiB) -# 902.020 μs (2221 allocations: 117.19 KiB) -# nalib - gpu -# 162.354 μs (410 allocations: 18.03 KiB) -# 604.111 μs (1263 allocations: 71.78 KiB) -# nnlib - gpu -# 156.383 μs (440 allocations: 20.00 KiB) -# 835.374 μs (1969 allocations: 100.58 KiB) - -## Threadripper, NNlib v0.8.13 (fast_maximum) -# tullio -# 4.599 ms (71 allocations: 7.13 MiB) -# 20.699 ms (1118 allocations: 16.59 MiB) -# nalib -# 5.049 ms (84 allocations: 7.63 MiB) -# 22.252 ms (599 allocations: 14.57 MiB) -# nnlib -# 5.378 ms (81 allocations: 9.13 MiB) -# 23.453 ms (1049 allocations: 20.58 MiB) -# tullio - gpu -# 145.824 μs (520 allocations: 24.52 KiB) -# 915.305 μs (2221 allocations: 117.19 KiB) -# nalib - gpu -# 164.789 μs (410 allocations: 18.03 KiB) -# 610.835 μs (1263 allocations: 71.78 KiB) -# nnlib - gpu -# 157.785 μs (440 allocations: 20.00 KiB) -# 852.087 μs (1969 allocations: 100.58 KiB) - - -# function prof() - # dim, len, batch_size, nheads = 128, 8, 128, 32; - # # dim = 384; len = 128; batch_size = 32; nheads = 12 - # mha = MultiHeadAttention(dim, nheads) - # x = rand(Float32, (dim, len, batch_size)) - # @btime mha(x, impl=:tullio); - # @btime mha(x, impl=:nnlib); - # @profview mha(x, impl=:tullio); - # @profview prof(mha, x); - # y, α = mha(x; impl=:nnlib, withscores=true, mask) - # y2, α2 = mha(x; impl=:nalib, withscores=true, mask=NeuralAttentionlib.CausalMask()) -# end \ No newline at end of file diff --git a/src/layers/attention_nnlib.jl b/src/layers/attention_nnlib.jl deleted file mode 100644 index 3c95439b03..0000000000 --- a/src/layers/attention_nnlib.jl +++ /dev/null @@ -1,151 +0,0 @@ -using ChainRulesCore -using NNlib - -const AA3{T} = AbstractArray{T,3} -const AA4{T} = AbstractArray{T,4} -const AA{N,T} = AbstractArray{T,N} - -""" - dot_product_attention(query, key, value, [bias]; [fdrop, mask, nheads]) - -Multihead dot product attention used in transformer architectures. - -The input arrays must have the first two dimensions given by the number of features -and the sequece length, then an arbitrary number of batch dimensions or none. 
- -Returns the attention output array of size `(v_dim, q_len, batch_size...)` and the attention scores. -of size `(kv_len, q_len, nheads, batch_size...)`. - -See also [`dot_product_attention_scores`](@ref) if you only need the attention scores. - -# Arguments - -- `query`: Query array of size `(qk_dim, q_len, batch_size...)`. -- `key`: Key array of size `(qk_dim, kv_len, batch_size...)`. -- `value`: Value array of size `(v_dim, kv_len, batch_size...)`. -- `bias`: Either `nothing` or an array broadcastable to size `(kv_len, q_len, nheads, batch_size)`. - It will be added to the attention scores before applying the softmax. Default `nothing`. -- `fdrop`: A dropout function or layer to be applied on the attention scores right after the softmax. - Default `identity` (no dropout). -- `mask`: Either `nothing` or a boolean array broadcastable to size `(kv_len, q_len, nheads, batch_size)`. - The mask is applied to the attention scores before the softmax. - Can also be set to `mask=:causal` to apply a causal mask. Default `nothing`. -- `nheads`: Number of heads to split the input arrays into. Default `1`. - -# Examples - -```julia -q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2) -y, α = dot_product_attention(q, k, v) -``` -""" -function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}, args...; kws...) where N - batch_size = size(q)[3:end] - batch_size == size(k)[3:end] == size(v)[3:end] || throw(ArgumentError("Batch dimensions have to be the same.")) - q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v)) - - x, α = dot_product_attention(q, k, v, args...; kws...) - - x = reshape(x, size(x, 1), size(x, 2), batch_size...) - α = reshape(α, size(α)[1:3]..., batch_size...) - return x, α -end - -function dot_product_attention(q::AA3, k::AA3, v::AA3, bias=nothing; - fdrop=identity, mask=nothing, nheads=1) - - (size(q, 3) == size(k, 3) == size(v, 3)) || throw(ArgumentError("Batch dimensions have to be the same.")) - size(q, 1) == size(k, 1) || throw(ArgumentError("First dimension in query and key has to be the same.")) - size(k, 2) == size(v, 2) || throw(ArgumentError("Second dimension in key and value has to be the same.")) - - # Multihead attention. TODO create fastpath for singlehead attention. - q, k, v = split_heads.((q, k, v), nheads) - x, α = _dot_product_attention(q, k, v, bias, fdrop, mask) - return join_heads(x), α -end - -function _dot_product_attention(q::AA4, k::AA4, v::AA4, bias, fdrop, mask) - # [q] = [qk_dim ÷ nheads, nheads, q_len, batch_size] - # [k] = [qk_dim ÷ nheads, nheads, kv_len, batch_size] - # [v] = [v_dim ÷ nheads, nheads, kv_len, batch_size] - - α = dot_product_attention_scores(q, k, bias; fdrop, mask) - # [α] = [kv_len, q_len, nheads, batch_size] - - # The following permutedims and batched_mul are equivalent to - # @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b] - vt = permutedims(v, (1, 3, 2, 4)) - x = batched_mul(vt, α) - x = permutedims(x, (1, 3, 2, 4)) - # [x] = [kv_dim ÷ nheads, nheads, q_len, batch_size] - return x, α -end - -""" - dot_product_attention_scores(query, key, [bias]; [fdrop, mask]) - -Return the attention scores for the [`dot_product_attention`](@ref). -Input arrays must have dimensions -`(num_features ÷ nheads, nheads, sequence_length, batch_size)`. - -See [`dot_product_attention`](@ref) for more details. 
-""" -function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing; - fdrop=identity, mask=nothing) where T - - # The following permutedims and batched_mul are equivalent to - # @tullio logits[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] / √T(qk_dim) - kt = permutedims(k, (3, 1, 2, 4)) - qt = permutedims(q, (1, 3, 2, 4)) ./ √T(size(q, 1)) - logits = batched_mul(kt, qt) - # [logits] = [kv_len, q_len, nheads, batch_size] - - if bias !== nothing - logits = logits .+ bias - end - - if mask !== nothing - if mask === :causal - mask = make_causal_mask(logits) - end - neginf = typemin(eltype(logits)) - logits = ifelse.(mask, logits, neginf) - end - - α = softmax(logits, dims=1) - return fdrop(α) -end - -""" - make_causal_mask(x, dims=2) - -Return a boolean square matrix `m` of the same type as `x` and of side `size(x, dims)`. -Its elements are set such that `m[i, j] == i ≤ j`. - -Can be used to mask the attention scores in [`dot_product_attention`](@ref). -""" -function make_causal_mask(x::AbstractArray; dims::Int=2) - len = size(x, dims) - mask = triu(trues_like(x, (len, len))) - return mask -end - -trues_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), true) -falses_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), false) - -split_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...) -join_heads(x) = reshape(x, :, size(x)[3:end]...) - -@non_differentiable make_causal_mask(x) -@non_differentiable trues_like(::Any...) -@non_differentiable falses_like(::Any...) - -function NNlib.batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N} - batch_size = size(x)[3:end] - @assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays." - x2 = reshape(x, size(x, 1), size(x, 2), :) - y2 = reshape(y, size(y, 1), size(y, 2), :) - z = batched_mul(x2, y2) - return reshape(z, size(z, 1), size(z, 2), batch_size...) - end - diff --git a/src/layers/attention_tullio.jl b/src/layers/attention_tullio.jl deleted file mode 100644 index 6d8a2ad6ec..0000000000 --- a/src/layers/attention_tullio.jl +++ /dev/null @@ -1,41 +0,0 @@ -using CUDAKernels, KernelAbstractions, LoopVectorization, Tullio - -reshape_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...) -flatten_heads(x) = reshape(x, :, size(x)[3:end]...) - -function dot_product_attention_tullio(nheads::Int, q::A3, k::A3, v::A3; kws...) - q, k, v = reshape_heads.((q, k, v), nheads) - x, α = dot_product_attention_tullio(q, k, v; kws...) - return flatten_heads(x), α -end - - -# Inspired by https://flax.readthedocs.io/en/latest/api_reference/_autosummary/flax.linen.dot_product_attention.html -function dot_product_attention_tullio(q::A4, k::A4, v::A4; - dropout=nothing, bias=nothing, mask=nothing) - - α = dot_product_attention_weights_tullio(q, k; dropout, bias, mask) - # [α] = [kv_len, q_len, nheads, batch_size] - @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b] - # [x] = [kv_dim ÷ nheads, nheads, q_len, batch_size] - return x, α -end - -function dot_product_attention_weights_tullio(q::A4{T}, k::A4{T}; - dropout=nothing, mask=nothing, bias=nothing) where T - - q = q ./ √T(size(q, 1)) - @tullio α[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] - # [α] = [kv_len, q_len, nheads, batch_size] - - if bias !== nothing - α = α .+ bias - end - if mask !== nothing - neginf = typemin(eltype(α)) - α = ifelse.(mask, α, neginf) - end - - α = softmax(α, dims=1) - return dropout === nothing ? 
α : dropout(α) -end diff --git a/test.jl b/test.jl new file mode 100644 index 0000000000..c51a899fd5 --- /dev/null +++ b/test.jl @@ -0,0 +1,37 @@ +using Flux, Test + + +@testset "attention" begin + dim = 4; nheads = 2; len = 3; batch_size = 5 + mha = MultiHeadAttention(dim, nheads) + q = rand(Float32, (dim, len, batch_size)) + k = rand(Float32, (dim, len, batch_size)) + v = rand(Float32, (dim, len, batch_size)) + + y, α = mha(q, k, v, withscores=true) + @test y isa Array{Float32, 3} + @test size(y) == (dim, len, batch_size) + @test α isa Array{Float32, 4} + @test size(α) == (len, len, nheads, batch_size) + + @testset "self-attention" begin + y1 = mha(q) + y2 = mha(q, q, q) + @test y1 ≈ y2 + end + + @testset "key and value are the same" begin + y1 = mha(q, k) + y2 = mha(q, k, k) + @test y1 ≈ y2 + end + + @testset "change dims" begin + dims = 4 => 10 => 5 + nhead = 5 + mha2 = MultiHeadAttention(dims, nheads) + y2 = mha2(q, k, v) + @test size(y2) == (dims.second.second, len, batch_size) + end +end + diff --git a/test/layers/attention.jl b/test/layers/attention.jl new file mode 100644 index 0000000000..809a2344b3 --- /dev/null +++ b/test/layers/attention.jl @@ -0,0 +1,63 @@ + + +@testset "attention" begin + dim = 4; nheads = 2; len = 3; batch_size = 5 + mha = MultiHeadAttention(dim; nheads) + q = rand(Float32, (dim, len, batch_size)) + k = rand(Float32, (dim, len, batch_size)) + v = rand(Float32, (dim, len, batch_size)) + + y, α = mha(q, k, v, withscores=true) + @test y isa Array{Float32, 3} + @test size(y) == (dim, len, batch_size) + @test α isa Array{Float32, 4} + @test size(α) == (len, len, nheads, batch_size) + + @testset "self-attention" begin + y1 = mha(q) + y2 = mha(q, q, q) + @test y1 ≈ y2 + end + + @testset "key and value are the same" begin + y1 = mha(q, k) + y2 = mha(q, k, k) + @test y1 ≈ y2 + end + + @testset "change dims" begin + dims = 4 => 10 => 5 + nhead = 5 + mha2 = MultiHeadAttention(dims; nheads) + y2 = mha2(q, k, v) + @test size(y2) == (dims.second.second, len, batch_size) + end + + @testset "mask" begin + mask = NNlib.make_causal_mask(q) + y, α = mha(q; mask, withscores=true) + @test all(α[2, 1, :, :] .== 0) + @test α[:, :, 1, 1] ≈ triu(α[:, :, 1, 1]) + end + + @testset "bias" begin + # use bias to produce a causal mask + b = zeros(Float32, (len, len)) + for i in 1:len, j in i:len + b[i, j] = typemax(Float32) + end + y, α = mha(q, k, v, b, withscores=true) + @test all(α[2, 1, :, :] .== 0) + @test α[:, :, 1, 1] ≈ triu(α[:, :, 1, 1]) + end + + @testset "gradient" begin + gm, gq = gradient(mha, q) do mha, q + y, α = mha(q, withscores=true) + return sum(y.^2) + sum(α.^2) + end + test_grad_type(gm, mha) + test_grad_type(gq, q) + end +end + diff --git a/test/test_utils.jl b/test/test_utils.jl index 2b07e59d08..12ff9d3855 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -82,3 +82,17 @@ function gpu_autodiff_test( end end end + + +test_grad_type(g::Nothing, x) = nothing + +function test_grad_type(g::AbstractArray{T1}, x::AbstractArray{T2}) where {T1, T2} + @test T1 == T2 + @test size(g) == size(x) +end + +function test_grad_type(g::NamedTuple, x::T) where T + for f in fieldnames(T) + test_grad_type(g[f], getfield(x, f)) + end +end From a1e83656b98738662b03d9b8cb197784787cf8e8 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Sun, 5 Mar 2023 12:34:35 +0100 Subject: [PATCH 15/21] cleanup --- test.jl | 37 ------------------------------------- 1 file changed, 37 deletions(-) delete mode 100644 test.jl diff --git a/test.jl b/test.jl deleted file mode 100644 index 
c51a899fd5..0000000000 --- a/test.jl +++ /dev/null @@ -1,37 +0,0 @@ -using Flux, Test - - -@testset "attention" begin - dim = 4; nheads = 2; len = 3; batch_size = 5 - mha = MultiHeadAttention(dim, nheads) - q = rand(Float32, (dim, len, batch_size)) - k = rand(Float32, (dim, len, batch_size)) - v = rand(Float32, (dim, len, batch_size)) - - y, α = mha(q, k, v, withscores=true) - @test y isa Array{Float32, 3} - @test size(y) == (dim, len, batch_size) - @test α isa Array{Float32, 4} - @test size(α) == (len, len, nheads, batch_size) - - @testset "self-attention" begin - y1 = mha(q) - y2 = mha(q, q, q) - @test y1 ≈ y2 - end - - @testset "key and value are the same" begin - y1 = mha(q, k) - y2 = mha(q, k, k) - @test y1 ≈ y2 - end - - @testset "change dims" begin - dims = 4 => 10 => 5 - nhead = 5 - mha2 = MultiHeadAttention(dims, nheads) - y2 = mha2(q, k, v) - @test size(y2) == (dims.second.second, len, batch_size) - end -end - From 2ecf19ba1d4a0381e50f5b3d8c176eb6e171ef28 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Sun, 5 Mar 2023 13:09:56 +0100 Subject: [PATCH 16/21] add cuda tests --- test/cuda/layers.jl | 26 ++++++++++++++++++++++++++ test/runtests.jl | 1 + test/test_utils.jl | 13 +++++++++++++ 3 files changed, 40 insertions(+) diff --git a/test/cuda/layers.jl b/test/cuda/layers.jl index a406c4129e..2218cb9d4e 100644 --- a/test/cuda/layers.jl +++ b/test/cuda/layers.jl @@ -338,3 +338,29 @@ end @test eltype(pool(reshape(gx,3,4,1))) == Float16 end end + +@testset "MultiHeadAttention" begin + dim = 4; nheads = 2; len = 3; batch_size = 5 + mha_cpu = MultiHeadAttention(dim; nheads) + x_cpu = rand(Float32, (dim, len, batch_size)) + y_cpu, α_cpu = mha_cpu(x_cpu, withscores=true) + + mha_gpu = mha_cpu |> gpu + x_gpu = x_cpu |> gpu + y_gpu, α_gpu = mha_gpu(x_gpu, withscores=true) + @test y_gpu isa CuArray{Float32} + @test α_gpu isa CuArray{Float32} + @test Array(y_gpu) ≈ y_cpu atol=1e-4 + @test Array(α_gpu) ≈ α_cpu atol=1e-4 + + gm_cpu, gx_cpu = gradient(mha_cpu, x_cpu) do mha, x + y, α = mha(x, withscores=true) + return sum(y.^2) + sum(α.^2) + end + gm_gpu, gx_gpu = gradient(mha_gpu, x_gpu) do mha, x + y, α = mha(x, withscores=true) + return sum(y.^2) + sum(α.^2) + end + test_grad_equal(gm_gpu, gm_cpu) + test_grad_equal(gx_gpu, gx_cpu) +end diff --git a/test/runtests.jl b/test/runtests.jl index a2a8f66323..a14372317c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,6 +33,7 @@ Random.seed!(0) end @testset "Layers" begin + include("layers/attention.jl") include("layers/basic.jl") include("layers/normalisation.jl") include("layers/stateless.jl") diff --git a/test/test_utils.jl b/test/test_utils.jl index 12ff9d3855..c52d266ab7 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -96,3 +96,16 @@ function test_grad_type(g::NamedTuple, x::T) where T test_grad_type(g[f], getfield(x, f)) end end + +test_grad_equal(g1::Nothing, g2::Nothing) = nothing + +function test_grad_equal(g1::AnyCuArray{T}, g2::Array{T}; atol=1e-4) where T + @test Array(g1) ≈ g2 atol=atol +end + +function test_grad_equal(g1::T1, g2::T2) where {T1 <: NamedTuple, T2 <: NamedTuple} + @test fieldnames(T1) == fieldnames(T2) + for f in fieldnames(T1) + test_grad_equal(g1[f], g2[f]) + end +end From b2809b256f412cbec7f01cc70860e9606ff93f06 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 10 Mar 2023 06:41:22 +0100 Subject: [PATCH 17/21] cleanup tests --- src/layers/attention.jl | 4 +-- test/cuda/layers.jl | 4 +-- test/test_utils.jl | 62 ++++++++++++++++++++--------------------- 3 files changed, 33 insertions(+), 
37 deletions(-) diff --git a/src/layers/attention.jl b/src/layers/attention.jl index a3bcdad230..42ea053b39 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -1,7 +1,5 @@ const A3{T} = AbstractArray{T, 3} -const TuplInt2 = Union{Int, Tuple{Int, Int}} -const TuplInt3 = Union{Int, Tuple{Int, Int, Int}} """ MultiHeadAttention(dims; [nheads, bias, init, dropout_prob]) @@ -86,7 +84,7 @@ end normalize_mha_dims(dims::Int) = (; q_in=dims, k_in=dims, v_in=dims, qk=dims, v=dims, out=dims) -function normalize_mha_dims((in, (qkv, out))::Pair{<:TuplInt3, <:Pair{<:TuplInt2, Int}}) +function normalize_mha_dims((in, (qkv, out))::Pair{<:Dims{3}, <:Pair{<:Dims{2}, Int}}) if in isa Int q_in = k_in = v_in = in else diff --git a/test/cuda/layers.jl b/test/cuda/layers.jl index 2218cb9d4e..64a2663082 100644 --- a/test/cuda/layers.jl +++ b/test/cuda/layers.jl @@ -361,6 +361,6 @@ end y, α = mha(x, withscores=true) return sum(y.^2) + sum(α.^2) end - test_grad_equal(gm_gpu, gm_cpu) - test_grad_equal(gx_gpu, gx_cpu) + check_grad(gm_gpu, gm_cpu) + check_grad(gx_gpu, gx_cpu) end diff --git a/test/test_utils.jl b/test/test_utils.jl index c52d266ab7..f71110877a 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -1,27 +1,33 @@ -function check_grad(g_gpu, g_cpu, atol, rtol; allow_nothing::Bool) +function check_grad(g_gpu, g_cpu; + rtol=1e-4, atol=1e-4, + allow_nothing::Bool=false) allow_nothing && return @show g_gpu g_cpu @test false end -check_grad(g_gpu::Base.RefValue, g_cpu::Base.RefValue, atol, rtol; allow_nothing::Bool) = - check_grad(g_gpu[], g_cpu[], atol, rtol; allow_nothing) -check_grad(g_gpu::Nothing, g_cpu::Nothing, atol, rtol; allow_nothing::Bool) = + +check_grad(g_gpu::Base.RefValue, g_cpu::Base.RefValue; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) = + check_grad(g_gpu[], g_cpu[]; rtol, atol, allow_nothing) + +check_grad(g_gpu::Nothing, g_cpu::Nothing; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) = @test true -check_grad(g_gpu::Float32, g_cpu::Float32, atol, rtol; allow_nothing::Bool) = + +check_grad(g_gpu::Float32, g_cpu::Float32; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) = @test g_cpu ≈ g_gpu rtol=rtol atol=atol -check_grad(g_gpu::CuArray{Float32}, g_cpu::Array{Float32}, atol, rtol; allow_nothing::Bool) = + +check_grad(g_gpu::CuArray{Float32}, g_cpu::Array{Float32}; rtol=eps32, allow_nothing::Bool=false) = @test g_cpu ≈ collect(g_gpu) rtol=rtol atol=atol -function check_grad(g_gpu::Tuple, g_cpu::Tuple, atol, rtol; allow_nothing::Bool) +function check_grad(g_gpu::Tuple, g_cpu::Tuple; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) for (v1, v2) in zip(g_gpu, g_cpu) - check_grad(v1, v2, atol, rtol; allow_nothing) + check_grad(v1, v2; rtol, atol, allow_nothing) end end -function check_grad(g_gpu::NamedTuple, g_cpu::NamedTuple, atol, rtol; allow_nothing::Bool) +function check_grad(g_gpu::NamedTuple, g_cpu::NamedTuple; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) for ((k1,v1), (k2,v2)) in zip(pairs(g_gpu), pairs(g_cpu)) @test k1 == k2 - check_grad(v1, v2, atol, rtol; allow_nothing) + check_grad(v1, v2; rtol, atol, allow_nothing) end end @@ -31,10 +37,14 @@ check_type(x::CuArray{Float32}) = true check_type(x::Array{Float32}) = true function gpu_autodiff_test( - f_cpu, xs_cpu::Array{Float32}...; - test_equal=true, rtol=1e-4, atol=1e-4, - checkgrad::Bool = true, allow_nothing::Bool = false, -) + f_cpu, + xs_cpu::Array{Float32}...; + test_equal=true, + rtol=1e-4, atol=1e-4, + checkgrad::Bool = true, + allow_nothing::Bool = false, + ) + # Compare CPU & GPU function 
outputs. f_gpu = f_cpu |> gpu xs_gpu = gpu.(xs_cpu) @@ -60,7 +70,7 @@ function gpu_autodiff_test( if test_equal @test collect(y_cpu) ≈ collect(y_gpu) rtol=rtol atol=atol for (g_gpu, g_cpu) in zip(gs_gpu, gs_cpu) - check_grad(g_gpu, g_cpu, atol, rtol; allow_nothing) + check_grad(g_gpu, g_cpu; atol, rtol, allow_nothing) end end @@ -78,34 +88,22 @@ function gpu_autodiff_test( @test collect(y_cpu) ≈ collect(y_gpu) rtol=rtol atol=atol @assert length(ps_gpu) == length(ps_cpu) for (p_gpu, p_cpu) in zip(ps_gpu, ps_cpu) - check_grad(gs_gpu[p_gpu], gs_cpu[p_cpu], atol, rtol; allow_nothing) + check_grad(gs_gpu[p_gpu], gs_cpu[p_cpu]; atol, rtol, allow_nothing) end end end +# check_grad_type checks that the gradient type matches the primal type. -test_grad_type(g::Nothing, x) = nothing +check_grad_type(g::Nothing, x) = nothing -function test_grad_type(g::AbstractArray{T1}, x::AbstractArray{T2}) where {T1, T2} +function check_grad_type(g::AbstractArray{T1}, x::AbstractArray{T2}) where {T1, T2} @test T1 == T2 @test size(g) == size(x) end -function test_grad_type(g::NamedTuple, x::T) where T +function check_grad_type(g::NamedTuple, x::T) where T for f in fieldnames(T) test_grad_type(g[f], getfield(x, f)) end end - -test_grad_equal(g1::Nothing, g2::Nothing) = nothing - -function test_grad_equal(g1::AnyCuArray{T}, g2::Array{T}; atol=1e-4) where T - @test Array(g1) ≈ g2 atol=atol -end - -function test_grad_equal(g1::T1, g2::T2) where {T1 <: NamedTuple, T2 <: NamedTuple} - @test fieldnames(T1) == fieldnames(T2) - for f in fieldnames(T1) - test_grad_equal(g1[f], g2[f]) - end -end From bd28c54a50c47b248cd2fd0879613c7bd960367b Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Fri, 10 Mar 2023 07:02:49 +0100 Subject: [PATCH 18/21] IntOrDims --- src/layers/attention.jl | 3 +- test/layers/attention.jl | 4 +-- test/runtests.jl | 70 ++++++++++++++++++++-------------------- test/test_utils.jl | 4 +-- 4 files changed, 41 insertions(+), 40 deletions(-) diff --git a/src/layers/attention.jl b/src/layers/attention.jl index 42ea053b39..54f2cdc43a 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -1,5 +1,6 @@ const A3{T} = AbstractArray{T, 3} +const IntOrDims{N} = Union{Int, Dims{N}} """ MultiHeadAttention(dims; [nheads, bias, init, dropout_prob]) @@ -84,7 +85,7 @@ end normalize_mha_dims(dims::Int) = (; q_in=dims, k_in=dims, v_in=dims, qk=dims, v=dims, out=dims) -function normalize_mha_dims((in, (qkv, out))::Pair{<:Dims{3}, <:Pair{<:Dims{2}, Int}}) +function normalize_mha_dims((in, (qkv, out))::Pair{<:IntOrDims{3}, <:Pair{<:IntOrDims{2}, Int}}) if in isa Int q_in = k_in = v_in = in else diff --git a/test/layers/attention.jl b/test/layers/attention.jl index 809a2344b3..15485e4bad 100644 --- a/test/layers/attention.jl +++ b/test/layers/attention.jl @@ -56,8 +56,8 @@ y, α = mha(q, withscores=true) return sum(y.^2) + sum(α.^2) end - test_grad_type(gm, mha) - test_grad_type(gq, q) + check_grad_type(gm, mha) + check_grad_type(gq, q) end end diff --git a/test/runtests.jl b/test/runtests.jl index a14372317c..137dbeb70a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,48 +13,48 @@ Random.seed!(0) @testset verbose=true "Flux.jl" begin - @testset "Utils" begin - include("utils.jl") - end + # @testset "Utils" begin + # include("utils.jl") + # end - @testset "Optimise / Train" begin - include("optimise.jl") - include("train.jl") - end + # @testset "Optimise / Train" begin + # include("optimise.jl") + # include("train.jl") + # end - @testset "Data" begin - include("data.jl") - end + # @testset "Data" begin 
+ # include("data.jl") + # end - @testset "Losses" begin - include("losses.jl") - include("ctc.jl") - CUDA.functional() && include("ctc-gpu.jl") - end + # @testset "Losses" begin + # include("losses.jl") + # include("ctc.jl") + # CUDA.functional() && include("ctc-gpu.jl") + # end - @testset "Layers" begin + # @testset "Layers" begin include("layers/attention.jl") - include("layers/basic.jl") - include("layers/normalisation.jl") - include("layers/stateless.jl") - include("layers/recurrent.jl") - include("layers/conv.jl") - include("layers/upsample.jl") - include("layers/show.jl") - end + # include("layers/basic.jl") + # include("layers/normalisation.jl") + # include("layers/stateless.jl") + # include("layers/recurrent.jl") + # include("layers/conv.jl") + # include("layers/upsample.jl") + # include("layers/show.jl") + # end - @testset "outputsize" begin - using Flux: outputsize - include("outputsize.jl") - end + # @testset "outputsize" begin + # using Flux: outputsize + # include("outputsize.jl") + # end - @testset "CUDA" begin - if CUDA.functional() - include("cuda/runtests.jl") - else - @warn "CUDA unavailable, not testing GPU support" - end - end + # @testset "CUDA" begin + # if CUDA.functional() + # include("cuda/runtests.jl") + # else + # @warn "CUDA unavailable, not testing GPU support" + # end + # end @static if VERSION == v"1.6" using Documenter diff --git a/test/test_utils.jl b/test/test_utils.jl index f71110877a..f07fb1c721 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -15,7 +15,7 @@ check_grad(g_gpu::Nothing, g_cpu::Nothing; rtol=1e-4, atol=1e-4, allow_nothing:: check_grad(g_gpu::Float32, g_cpu::Float32; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) = @test g_cpu ≈ g_gpu rtol=rtol atol=atol -check_grad(g_gpu::CuArray{Float32}, g_cpu::Array{Float32}; rtol=eps32, allow_nothing::Bool=false) = +check_grad(g_gpu::CuArray{Float32}, g_cpu::Array{Float32}; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) = @test g_cpu ≈ collect(g_gpu) rtol=rtol atol=atol function check_grad(g_gpu::Tuple, g_cpu::Tuple; rtol=1e-4, atol=1e-4, allow_nothing::Bool=false) @@ -104,6 +104,6 @@ end function check_grad_type(g::NamedTuple, x::T) where T for f in fieldnames(T) - test_grad_type(g[f], getfield(x, f)) + check_grad_type(g[f], getfield(x, f)) end end From 111d8145fb48b3c8a9e30ccd9c66ed3a7e68c51c Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Fri, 10 Mar 2023 07:04:37 +0100 Subject: [PATCH 19/21] cleanup --- test/runtests.jl | 70 ++++++++++++++++++++++++------------------------ 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 137dbeb70a..a14372317c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -13,48 +13,48 @@ Random.seed!(0) @testset verbose=true "Flux.jl" begin - # @testset "Utils" begin - # include("utils.jl") - # end + @testset "Utils" begin + include("utils.jl") + end - # @testset "Optimise / Train" begin - # include("optimise.jl") - # include("train.jl") - # end + @testset "Optimise / Train" begin + include("optimise.jl") + include("train.jl") + end - # @testset "Data" begin - # include("data.jl") - # end + @testset "Data" begin + include("data.jl") + end - # @testset "Losses" begin - # include("losses.jl") - # include("ctc.jl") - # CUDA.functional() && include("ctc-gpu.jl") - # end + @testset "Losses" begin + include("losses.jl") + include("ctc.jl") + CUDA.functional() && include("ctc-gpu.jl") + end - # @testset "Layers" begin + @testset "Layers" begin include("layers/attention.jl") - # include("layers/basic.jl") 
-  #     include("layers/normalisation.jl")
-  #     include("layers/stateless.jl")
-  #     include("layers/recurrent.jl")
-  #     include("layers/conv.jl")
-  #     include("layers/upsample.jl")
-  #     include("layers/show.jl")
-  # end
+    include("layers/basic.jl")
+    include("layers/normalisation.jl")
+    include("layers/stateless.jl")
+    include("layers/recurrent.jl")
+    include("layers/conv.jl")
+    include("layers/upsample.jl")
+    include("layers/show.jl")
+  end

-  # @testset "outputsize" begin
-  #   using Flux: outputsize
-  #   include("outputsize.jl")
-  # end
+  @testset "outputsize" begin
+    using Flux: outputsize
+    include("outputsize.jl")
+  end

-  # @testset "CUDA" begin
-  #   if CUDA.functional()
-  #     include("cuda/runtests.jl")
-  #   else
-  #     @warn "CUDA unavailable, not testing GPU support"
-  #   end
-  # end
+  @testset "CUDA" begin
+    if CUDA.functional()
+      include("cuda/runtests.jl")
+    else
+      @warn "CUDA unavailable, not testing GPU support"
+    end
+  end

   @static if VERSION == v"1.6"
     using Documenter

From 0b108231f1ff90ac410e2c2f6af2a4d55d647179 Mon Sep 17 00:00:00 2001
From: CarloLucibello
Date: Fri, 10 Mar 2023 07:29:42 +0100
Subject: [PATCH 20/21] remove with_scores

---
 src/layers/attention.jl  | 28 ++++++++++++++++------------
 test/cuda/layers.jl      |  8 ++++----
 test/layers/attention.jl | 20 +++++++++++---------
 3 files changed, 31 insertions(+), 25 deletions(-)

diff --git a/src/layers/attention.jl b/src/layers/attention.jl
index 54f2cdc43a..56f1e36427 100644
--- a/src/layers/attention.jl
+++ b/src/layers/attention.jl
@@ -7,17 +7,19 @@ const IntOrDims{N} = Union{Int, Dims{N}}

 The multi-head dot-product attention layer used in Transformer architectures [1].

+Returns the transformed input sequence and the attention scores.
+
 [1] Vaswani et al. "Attention is all you need."
 Advances in Neural Information Processing Systems. 2017.

 # Arguments

 - `dims`: The embedding dimensions of inputs, intermediate tensors and outputs.
   In the most general case, it is given as
-  `(q_in_dim, k_in_dim, v_in_dim) => (qk_dim, v_dim) => out_dim`.
+  a) `(q_in_dim, k_in_dim, v_in_dim) => (qk_dim, v_dim) => out_dim`.
   Can take also simpler forms as
-  `dims::Int`, `in_dim::Int => (qk_dim, v_dim) => out_dim`,
-  `in_dim::Int => qkv_dim => out_dim`.
-
+  b) `dims::Int`;
+  c) `in_dim::Int => (qk_dim, v_dim) => out_dim`;
+  d) `in_dim::Int => qkv_dim => out_dim`.
 - `nheads`: number of heads. Default `8`.
 - `init`: weight initializer for the Dense layers. Default `glorot_uniform`.
 - `bias` : whether pointwise QKVO dense transforms use bias. Default `false`.
@@ -25,19 +27,17 @@ The multi-head dot-product attention layer used in Transformer architectures [1]

 # Forward

-    (mha::MultiHeadAttention)(q_in, k_in, v_in, [bias]; [mask, withscores])
+    (mha::MultiHeadAttention)(q_in, k_in, v_in, [bias]; [mask])

 - `q_in`: input query array of size `(q_in_dim, q_len, batch_size...)`.
 - `k_in`: input key array of size `(k_in_dim, kv_len, batch_size...)`.
 - `v_in`: input value array of size `(v_in_dim, kv_len, batch_size...)`.
 - `mask`: input array broadcastable to size
   `(kv_len, q_len, nheads, batch_size)`. Default `nothing`.
-- `withscores`: Whether to return the attention scores. Default `false`.

 In alternative, `mha(q_in)` is equivalent to `mha(q_in, q_in, q_in)` (self-attention)
 and `mha(q_in, k_in)` is equivalent to `mha(q_in, k_in, k_in)` (key and value are the same).
-
 See also [`NNlib.dot_product_attention`](@ref).
# Examples @@ -47,10 +47,14 @@ mha = MultiHeadAttention(64, nheads = 8) q = rand(Float32, (64, 10, 32)) k = rand(Float32, (64, 20, 32)) v = rand(Float32, (64, 20, 32)) -y = mha(q, k, v) # [y] = [64, 10, 32] +y, α = mha(q, k, v) +# [y] = [64, 10, 32] +# [α] = [20, 10, 8, 32] mha = MultiHeadAttention(64 => 1024 => 1024, nheads = 8) -y = mha(q) # self-attention; [y] = [1024, 10, 32] +y, α = mha(q) # self-attention +# [y] = [1024, 10, 32] +# [α] = [10, 10, 8, 32] ``` """ struct MultiHeadAttention{P1, D, P2} @@ -105,8 +109,8 @@ end # key and value are the same (mha::MultiHeadAttention)(q, kv; kws...) = mha(q, kv, kv; kws...) -function (mha::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3, bias=nothing; - withscores=false, mask=nothing) +function (mha::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3, + bias=nothing; mask=nothing) ## [q_in] = [q_in_dim, q_len, batch_size] ## [k_in] = [k_in_dim, kv_len, batch_size] ## [v_in] = [v_in_dim, kv_len, batch_size] @@ -117,5 +121,5 @@ function (mha::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3, bias=nothing; x = mha.out_proj(x) # [x] = [out_dim, q_len, batch_size] # [α] = [kv_len, q_len, nheads, batch_size] - return withscores ? (x, α) : x + return x, α end diff --git a/test/cuda/layers.jl b/test/cuda/layers.jl index 64a2663082..90c7ab0b40 100644 --- a/test/cuda/layers.jl +++ b/test/cuda/layers.jl @@ -343,22 +343,22 @@ end dim = 4; nheads = 2; len = 3; batch_size = 5 mha_cpu = MultiHeadAttention(dim; nheads) x_cpu = rand(Float32, (dim, len, batch_size)) - y_cpu, α_cpu = mha_cpu(x_cpu, withscores=true) + y_cpu, α_cpu = mha_cpu(x_cpu) mha_gpu = mha_cpu |> gpu x_gpu = x_cpu |> gpu - y_gpu, α_gpu = mha_gpu(x_gpu, withscores=true) + y_gpu, α_gpu = mha_gpu(x_gpu) @test y_gpu isa CuArray{Float32} @test α_gpu isa CuArray{Float32} @test Array(y_gpu) ≈ y_cpu atol=1e-4 @test Array(α_gpu) ≈ α_cpu atol=1e-4 gm_cpu, gx_cpu = gradient(mha_cpu, x_cpu) do mha, x - y, α = mha(x, withscores=true) + y, α = mha(x) return sum(y.^2) + sum(α.^2) end gm_gpu, gx_gpu = gradient(mha_gpu, x_gpu) do mha, x - y, α = mha(x, withscores=true) + y, α = mha(x) return sum(y.^2) + sum(α.^2) end check_grad(gm_gpu, gm_cpu) diff --git a/test/layers/attention.jl b/test/layers/attention.jl index 15485e4bad..a4c90b36ed 100644 --- a/test/layers/attention.jl +++ b/test/layers/attention.jl @@ -7,35 +7,37 @@ k = rand(Float32, (dim, len, batch_size)) v = rand(Float32, (dim, len, batch_size)) - y, α = mha(q, k, v, withscores=true) + y, α = mha(q, k, v) @test y isa Array{Float32, 3} @test size(y) == (dim, len, batch_size) @test α isa Array{Float32, 4} @test size(α) == (len, len, nheads, batch_size) @testset "self-attention" begin - y1 = mha(q) - y2 = mha(q, q, q) + y1, α1 = mha(q) + y2, α2 = mha(q, q, q) @test y1 ≈ y2 + @test α1 ≈ α2 end @testset "key and value are the same" begin - y1 = mha(q, k) - y2 = mha(q, k, k) + y1, α1 = mha(q, k) + y2, α2 = mha(q, k, k) @test y1 ≈ y2 + @test α1 ≈ α2 end @testset "change dims" begin dims = 4 => 10 => 5 nhead = 5 mha2 = MultiHeadAttention(dims; nheads) - y2 = mha2(q, k, v) + y2, _ = mha2(q, k, v) @test size(y2) == (dims.second.second, len, batch_size) end @testset "mask" begin mask = NNlib.make_causal_mask(q) - y, α = mha(q; mask, withscores=true) + y, α = mha(q; mask) @test all(α[2, 1, :, :] .== 0) @test α[:, :, 1, 1] ≈ triu(α[:, :, 1, 1]) end @@ -46,14 +48,14 @@ for i in 1:len, j in i:len b[i, j] = typemax(Float32) end - y, α = mha(q, k, v, b, withscores=true) + y, α = mha(q, k, v, b) @test all(α[2, 1, :, :] .== 0) @test α[:, :, 1, 1] ≈ triu(α[:, :, 1, 1]) 
   end

   @testset "gradient" begin
     gm, gq = gradient(mha, q) do mha, q
-      y, α = mha(q, withscores=true)
+      y, α = mha(q)
       return sum(y.^2) + sum(α.^2)
     end
     check_grad_type(gm, mha)

From 29afec7f20fbb3a86ddd66309a42dce4976b457f Mon Sep 17 00:00:00 2001
From: CarloLucibello
Date: Fri, 10 Mar 2023 07:44:17 +0100
Subject: [PATCH 21/21] improve docstring

---
 src/layers/attention.jl | 24 ++++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/src/layers/attention.jl b/src/layers/attention.jl
index 56f1e36427..5b4dacfaf1 100644
--- a/src/layers/attention.jl
+++ b/src/layers/attention.jl
@@ -29,14 +29,22 @@ Returns the transformed input sequence and the attention scores.

     (mha::MultiHeadAttention)(q_in, k_in, v_in, [bias]; [mask])

-- `q_in`: input query array of size `(q_in_dim, q_len, batch_size...)`.
-- `k_in`: input key array of size `(k_in_dim, kv_len, batch_size...)`.
-- `v_in`: input value array of size `(v_in_dim, kv_len, batch_size...)`.
-- `mask`: input array broadcastable to size
-  `(kv_len, q_len, nheads, batch_size)`. Default `nothing`.
-
-In alternative, `mha(q_in)` is equivalent to `mha(q_in, q_in, q_in)` (self-attention)
-and `mha(q_in, k_in)` is equivalent to `mha(q_in, k_in, k_in)` (key and value are the same).
+The arguments of the forward pass are:
+
+- `q_in`: Input query array of size `(q_in_dim, q_len, batch_size)`.
+- `k_in`: Input key array of size `(k_in_dim, kv_len, batch_size)`.
+- `v_in`: Input value array of size `(v_in_dim, kv_len, batch_size)`.
+- `bias`: Bias array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
+  It will be added to the attention scores before the softmax.
+  Default `nothing`.
+- `mask`: Input array broadcastable to size
+  `(kv_len, q_len, nheads, batch_size)`.
+  The mask is applied to the attention scores just before the softmax.
+  See [`NNlib.make_causal_mask`](@ref) for creating causal masks.
+  Default `nothing`.
+
+Alternative calling signatures are `mha(q_in)`, equivalent to `mha(q_in, q_in, q_in)` (self-attention),
+and `mha(q_in, k_in)`, equivalent to `mha(q_in, k_in, k_in)` (key and value are the same).

 See also [`NNlib.dot_product_attention`](@ref).
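To make the API change in these last two patches concrete, here is a small usage sketch based only on the docstring and tests above: the layer now always returns the pair `(output, scores)`, takes an optional positional `bias`, and accepts a `mask` keyword. It assumes a Flux build that already includes this patch series; the variable names and array sizes are illustrative, not part of the patches.

```julia
using Flux, NNlib

# Self-attention: the forward pass now returns (output, attention scores),
# as introduced in [PATCH 20/21].
mha = MultiHeadAttention(64; nheads = 8)
q = rand(Float32, 64, 10, 32)        # (dim, q_len, batch_size)
y, α = mha(q)                        # y: (64, 10, 32), α: (10, 10, 8, 32)

# Causal self-attention via the `mask` keyword, as documented in [PATCH 21/21].
mask = NNlib.make_causal_mask(q)     # (10, 10), broadcastable to (kv_len, q_len, nheads, batch)
y_causal, α_causal = mha(q; mask)

# Cross-attention with an additive bias, mirroring the "bias" testset above:
# the bias is added to the attention scores before the softmax.
k = rand(Float32, 64, 20, 32)
v = rand(Float32, 64, 20, 32)
bias = zeros(Float32, 20, 10)        # (kv_len, q_len), broadcast over heads and batch
y_cross, α_cross = mha(q, k, v, bias)
@assert size(α_cross) == (20, 10, 8, 32)   # scores are (kv_len, q_len, nheads, batch_size)

# Pair form of `dims` (form d in the docstring): in_dim => qkv_dim => out_dim,
# with nheads dividing the inner dimension.
mha2 = MultiHeadAttention(4 => 10 => 5; nheads = 2)
y2, _ = mha2(rand(Float32, 4, 3, 7))       # y2: (5, 3, 7)
```

Returning the scores unconditionally, rather than via the removed `withscores` keyword, keeps the return type of the forward pass stable regardless of the call site.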