using Flux, Functors, NNlib
using Flux: glorot_uniform

const A3{T} = AbstractArray{T, 3}

const TuplInt2 = Union{Int, Tuple{Int, Int}}
const TuplInt3 = Union{Int, Tuple{Int, Int, Int}}

186"""
19- MultiHeadAttention(dims, nheads; [bias, init, dropout_prob])
7+ MultiHeadAttention(dims; [nheads, bias, init, dropout_prob])
8+
9+ The multi-head dot-product attention layer used in Transformer architectures [1].
2010
21- Multi-head dot-product attention layer .
11+ [1] Vaswani et al. "Attention is all you need." Advances in Neural Information Processing Systems. 2017 .
2212
2313# Arguments
2414
25- - `dims`: ...
26- - `nheads`: number of heads.
27- - `init`: weight initializer for the Dense layers.
28- - `bias` : whether pointwise QKVO dense transforms use bias.
29- - `dropout_prob`: dropout probability for the attention scores.
15+ - `dims`: The embedding dimensions of inputs, intermediate tensors and outputs.
16+ In the most general case, it is given as
17+ `(q_in_dim, k_in_dim, v_in_dim) => (qk_dim, v_dim) => out_dim`.
18+ Can take also simpler forms as
19+ `dims::Int`, `in_dim::Int => (qk_dim, v_dim) => out_dim`,
20+ `in_dim::Int => qkv_dim => out_dim`.
21+
22+ - `nheads`: number of heads. Default `8`.
23+ - `init`: weight initializer for the Dense layers. Default `glorot_uniform`.
24+ - `bias` : whether pointwise QKVO dense transforms use bias. Default `false`.
25+ - `dropout_prob`: dropout probability for the attention scores. Default `0.0`.
3026
3127# Forward
3228
33- (::MultiHeadAttention)(q_in, k_in, v_in, [bias]; [mask, withscores])
29+ (mha ::MultiHeadAttention)(q_in, k_in, v_in, [bias]; [mask, withscores])
3430
3531- `q_in`: input query array of size `(q_in_dim, q_len, batch_size...)`.
3632- `k_in`: input key array of size `(k_in_dim, kv_len, batch_size...)`.
@@ -39,38 +35,58 @@ Multi-head dot-product attention layer.
3935 `(kv_len, q_len, nheads, batch_size)`. Default `nothing`.
4036- `withscores`: Whether to return the attention scores. Default `false`.
4137
38+ In alternative, `mha(q_in)` is equivalent to `mha(q_in, q_in, q_in)` (self-attention)
39+ and `mha(q_in, k_in)` is equivalent to `mha(q_in, k_in, k_in)` (key and value are the same).
40+
41+
42+ See also [`NNlib.dot_product_attention`](@ref).
43+
4244# Examples
4345
4446```julia
45- mha = MultiHeadAttention(64, 8)
47+ mha = MultiHeadAttention(64, nheads = 8)
48+ q = rand(Float32, (64, 10, 32))
49+ k = rand(Float32, (64, 20, 32))
50+ v = rand(Float32, (64, 20, 32))
51+ y = mha(q, k, v) # [y] = [64, 10, 32]
52+
53+ mha = MultiHeadAttention(64 => 1024 => 1024, nheads = 8)
54+ y = mha(q) # self-attention; [y] = [1024, 10, 32]
4655```
4756"""
struct MultiHeadAttention{P1, D, P2}
    nheads::Int
    q_proj::P1
    k_proj::P1
    v_proj::P1
    attn_drop::D
    out_proj::P2
end

@functor MultiHeadAttention

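# `@functor` registers the layer's fields with Functors.jl, so that
# `Flux.params`, `gpu`/`cpu` device movement, and optimiser setup all recurse
# into the three input projections, the dropout, and the output projection.
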
function MultiHeadAttention(dims;
                            nheads::Int = 8,
                            bias::Bool = false,
                            init = glorot_uniform,
                            dropout_prob = 0.0)

    dims = normalize_mha_dims(dims)
    @assert dims.qk % nheads == 0 "qk_dim should be divisible by nheads"
    @assert dims.v % nheads == 0 "v_dim should be divisible by nheads"
    q_proj = Dense(dims.q_in => dims.qk; bias, init)
    k_proj = Dense(dims.k_in => dims.qk; bias, init)
    v_proj = Dense(dims.v_in => dims.v; bias, init)
    attn_drop = Dropout(dropout_prob)
    out_proj = Dense(dims.v => dims.out; bias, init)
    return MultiHeadAttention(nheads, q_proj, k_proj, v_proj, attn_drop, out_proj)
end

# turns the dims argument into a named tuple
normalize_mha_dims(dims::Int) =
    (; q_in=dims, k_in=dims, v_in=dims, qk=dims, v=dims, out=dims)

function normalize_mha_dims((in, (qkv, out))::Pair{<:TuplInt3, <:Pair{<:TuplInt2, Int}})
    if in isa Int
        q_in = k_in = v_in = in
    else
        q_in, k_in, v_in = in
    end
    if qkv isa Int
        qk = v = qkv
    else
        qk, v = qkv
    end
    return (; q_in, k_in, v_in, qk, v, out)
end
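
# Quick sanity checks of the accepted `dims` forms (a minimal sketch; the
# expected named tuples follow directly from the definitions above):
@assert normalize_mha_dims(64) == (; q_in=64, k_in=64, v_in=64, qk=64, v=64, out=64)
@assert normalize_mha_dims(64 => 128 => 32) == (; q_in=64, k_in=64, v_in=64, qk=128, v=128, out=32)
@assert normalize_mha_dims((3, 4, 5) => (10, 12) => 7) == (; q_in=3, k_in=4, v_in=5, qk=10, v=12, out=7)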

# self-attention
(mha::MultiHeadAttention)(qkv; kws...) = mha(qkv, qkv, qkv; kws...)

# key and value are the same
(mha::MultiHeadAttention)(q, kv; kws...) = mha(q, kv, kv; kws...)

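# Construction sketch (arbitrary sizes, for illustration): the most general
# `dims` form, with distinct q/k/v input dims, a shared qk dim, and separate
# v and output dims. A forward-pass example follows the definition below.
mha_sketch = MultiHeadAttention((10, 20, 30) => (64, 128) => 50, nheads = 8)
# q_proj: Dense(10 => 64), k_proj: Dense(20 => 64),
# v_proj: Dense(30 => 128), out_proj: Dense(128 => 50)
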
function (mha::MultiHeadAttention)(q_in::A3, k_in::A3, v_in::A3, bias=nothing;
                                   withscores=false, mask=nothing)
    ## [q_in] = [q_in_dim, q_len, batch_size]
    ## [k_in] = [k_in_dim, kv_len, batch_size]
    ## [v_in] = [v_in_dim, kv_len, batch_size]
    q = mha.q_proj(q_in)  # [q] = [qk_dim, q_len, batch_size]
    k = mha.k_proj(k_in)  # [k] = [qk_dim, kv_len, batch_size]
    v = mha.v_proj(v_in)  # [v] = [v_dim, kv_len, batch_size]
    x, α = NNlib.dot_product_attention(q, k, v, bias; mha.nheads, mask, fdrop=mha.attn_drop)
    x = mha.out_proj(x)
    # [x] = [out_dim, q_len, batch_size]
    # [α] = [kv_len, q_len, nheads, batch_size]
    return withscores ? (x, α) : x
end
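
# Usage sketch (hypothetical values): causal self-attention that also returns
# the attention scores. `NNlib.make_causal_mask` builds a boolean mask from the
# length dimension of its argument.
mha_causal = MultiHeadAttention(64, nheads = 8)
x = rand(Float32, 64, 10, 32)
causal_mask = NNlib.make_causal_mask(x)  # [causal_mask] = [10, 10]
y, α = mha_causal(x; mask = causal_mask, withscores = true)
@assert size(y) == (64, 10, 32)
@assert size(α) == (10, 10, 8, 32)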
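
# Gradient sketch (hypothetical loss): the layer is differentiable end to end,
# e.g. with Zygote through `Flux.gradient`.
loss(m, x) = sum(abs2, m(x))
grads = Flux.gradient(m -> loss(m, x), mha_causal)[1]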