# reference implementation on the CPU
-
- # note that most of the code in this file serves to define a functional array type,
- # the actual implementation of GPUArrays-interfaces is much more limited.
+ # This acts as a wrapper around KernelAbstractions's parallel CPU
+ # functionality. It is useful for testing GPUArrays (and other packages)
+ # when no GPU is present.
+ # This file follows conventions from AMDGPU.jl
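For orientation, a minimal usage sketch (a hypothetical session; `jl` and the `Array` conversion are defined later in this file):

    using JLArrays
    a = jl(rand(Float32, 4, 4))   # wrap a host Array in a JLArray
    b = a .+ 1f0                  # broadcasting runs through the GPUArrays machinery, on CPU tasks
    Array(b)                      # copy the result back to a plain Array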

module JLArrays

- export JLArray, JLVector, JLMatrix, jl
-
using GPUArrays
-
using Adapt
+ import KernelAbstractions
+ import KernelAbstractions: Adapt, StaticArrays, Backend, Kernel, StaticSize, DynamicSize, partition, blocks, workitems, launch_config

+ export JLArray, JLVector, JLMatrix, jl, JLBackend

#
# Device functionality
#

const MAXTHREADS = 256

-
- ## execution
-
- struct JLBackend <: AbstractGPUBackend end
-
- mutable struct JLKernelContext <: AbstractKernelContext
-     blockdim::Int
-     griddim::Int
-     blockidx::Int
-     threadidx::Int
-
-     localmem_counter::Int
-     localmems::Vector{Vector{Array}}
- end
-
- function JLKernelContext(threads::Int, blockdim::Int)
-     blockcount = prod(blockdim)
-     lmems = [Vector{Array}() for i in 1:blockcount]
-     JLKernelContext(threads, blockdim, 1, 1, 0, lmems)
+ struct JLBackend <: KernelAbstractions.GPU
+     static::Bool
+     JLBackend(;static::Bool=false) = new(static)
end
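The `static` flag mirrors the scheduling option of KernelAbstractions' CPU backend and is simply forwarded when a kernel is converted for execution (see `convert_to_cpu` further down). A small construction sketch:

    backend = JLBackend()                     # default: dynamic scheduling
    static_backend = JLBackend(static=true)   # forwarded as KernelAbstractions.CPU(; static=true)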

- function JLKernelContext(ctx::JLKernelContext, threadidx::Int)
-     JLKernelContext(
-         ctx.blockdim,
-         ctx.griddim,
-         ctx.blockidx,
-         threadidx,
-         0,
-         ctx.localmems
-     )
- end

struct Adaptor end
jlconvert(arg) = adapt(Adaptor(), arg)

Base.getindex(r::JlRefValue) = r.x
Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = JlRefValue(adapt(to, r[]))

- function GPUArrays.gpu_call(::JLBackend, f, args, threads::Int, blocks::Int;
-                             name::Union{String,Nothing})
-     ctx = JLKernelContext(threads, blocks)
-     device_args = jlconvert.(args)
-     tasks = Array{Task}(undef, threads)
-     for blockidx in 1:blocks
-         ctx.blockidx = blockidx
-         for threadidx in 1:threads
-             thread_ctx = JLKernelContext(ctx, threadidx)
-             tasks[threadidx] = @async f(thread_ctx, device_args...)
-             # TODO: require 1.3 and use Base.Threads.@spawn for actual multithreading
-             # (this would require a different synchronization mechanism)
-         end
-         for t in tasks
-             fetch(t)
-         end
+ mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
+     data::DataRef{Vector{UInt8}}
+
+     offset::Int   # offset of the data in the buffer, in number of elements
+
+     dims::Dims{N}
+
+     # allocating constructor
+     function JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
+         check_eltype(T)
+         maxsize = prod(dims) * sizeof(T)
+         data = Vector{UInt8}(undef, maxsize)
+         ref = DataRef(data)
+         obj = new{T,N}(ref, 0, dims)
+         finalizer(unsafe_free!, obj)
    end
-     return
- end

+     # low-level constructor for wrapping existing data
+     function JLArray{T,N}(ref::DataRef{Vector{UInt8}}, dims::Dims{N};
+                           offset::Int=0) where {T,N}
+         check_eltype(T)
+         obj = new{T,N}(ref, offset, dims)
+         finalizer(unsafe_free!, obj)
+     end
+ end
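A sketch of how the two constructors are meant to be used (the allocating form is the common entry point; the wrapping form is what derived arrays use to share a buffer at a different element offset):

    a = JLArray{Float32,2}(undef, (4, 4))   # backs the array with a fresh 4*4*sizeof(Float32) byte buffer
    # the second constructor instead wraps an existing DataRef, e.g. to let a derived
    # array reuse a's buffer while starting `offset` elements into it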

- ## executed on-device
+ Adapt.adapt_storage(::JLBackend, a::Array) = Adapt.adapt(JLArrays.JLArray, a)
+ Adapt.adapt_storage(::JLBackend, a::JLArrays.JLArray) = a
+ Adapt.adapt_storage(::KernelAbstractions.CPU, a::JLArrays.JLArray) = convert(Array, a)
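These rules are what let `Adapt.adapt` shuttle data between the host and the emulated device; a round-trip sketch:

    using Adapt
    a = rand(Float32, 16)
    d = adapt(JLBackend(), a)               # Array   -> JLArray
    h = adapt(KernelAbstractions.CPU(), d)  # JLArray -> Array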
8567
8668# array type
8769
10789@inline Base. getindex (A:: JLDeviceArray , index:: Integer ) = getindex (typed_data (A), index)
10890@inline Base. setindex! (A:: JLDeviceArray , x, index:: Integer ) = setindex! (typed_data (A), x, index)
10991
110-
111- # indexing
112-
113- for f in (:blockidx , :blockdim , :threadidx , :griddim )
114- @eval GPUArrays.$ f (ctx:: JLKernelContext ) = ctx.$ f
115- end
116-
117- # memory
118-
119- function GPUArrays. LocalMemory (ctx:: JLKernelContext , :: Type{T} , :: Val{dims} , :: Val{id} ) where {T, dims, id}
120- ctx. localmem_counter += 1
121- lmems = ctx. localmems[blockidx (ctx)]
122-
123- # first invocation in block
124- data = if length (lmems) < ctx. localmem_counter
125- lmem = fill (zero (T), dims)
126- push! (lmems, lmem)
127- lmem
128- else
129- lmems[ctx. localmem_counter]
130- end
131-
132- N = length (dims)
133- JLDeviceArray {T,N} (data, tuple (dims... ))
134- end
135-
136- # synchronization
137-
138- @inline function GPUArrays. synchronize_threads (:: JLKernelContext )
139- # All threads are getting started asynchronously, so a yield will yield to the next
140- # execution of the same function, which should call yield at the exact same point in the
141- # program, leading to a chain of yields effectively syncing the tasks (threads).
142- yield ()
143- return
144- end
145-
146-
14792#
14893# Host abstractions
14994#
@@ -157,32 +102,6 @@ function check_eltype(T)
    end
end

- mutable struct JLArray{T, N} <: AbstractGPUArray{T, N}
-     data::DataRef{Vector{UInt8}}
-
-     offset::Int   # offset of the data in the buffer, in number of elements
-
-     dims::Dims{N}
-
-     # allocating constructor
-     function JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N}
-         check_eltype(T)
-         maxsize = prod(dims) * sizeof(T)
-         data = Vector{UInt8}(undef, maxsize)
-         ref = DataRef(data)
-         obj = new{T,N}(ref, 0, dims)
-         finalizer(unsafe_free!, obj)
-     end
-
-     # low-level constructor for wrapping existing data
-     function JLArray{T,N}(ref::DataRef{Vector{UInt8}}, dims::Dims{N};
-                           offset::Int=0) where {T,N}
-         check_eltype(T)
-         obj = new{T,N}(ref, offset, dims)
-         finalizer(unsafe_free!, obj)
-     end
- end
-
unsafe_free!(a::JLArray) = GPUArrays.unsafe_free!(a.data)

# conversion of untyped data to a typed Array

## GPUArrays interfaces

- GPUArrays.backend(::Type{<:JLArray}) = JLBackend()
-
Adapt.adapt_storage(::Adaptor, x::JLArray{T,N}) where {T,N} =
    JLDeviceArray{T,N}(x.data[], x.offset, x.dims)

@@ -406,4 +323,47 @@ function GPUArrays.mapreducedim!(f, op, R::AnyJLArray, A::Union{AbstractArray,Br
    R
end

+ ## KernelAbstractions interface
+
+ KernelAbstractions.get_backend(a::JLA) where JLA <: JLArray = JLBackend()
+
+ function KernelAbstractions.mkcontext(kernel::Kernel{JLBackend}, I, _ndrange, iterspace, ::Dynamic) where Dynamic
+     return KernelAbstractions.CompilerMetadata{KernelAbstractions.ndrange(kernel), Dynamic}(I, _ndrange, iterspace)
+ end
+
+ KernelAbstractions.allocate(::JLBackend, ::Type{T}, dims::Tuple) where T = JLArray{T}(undef, dims)
+
+ @inline function launch_config(kernel::Kernel{JLBackend}, ndrange, workgroupsize)
+     if ndrange isa Integer
+         ndrange = (ndrange,)
+     end
+     if workgroupsize isa Integer
+         workgroupsize = (workgroupsize,)
+     end
+
+     if KernelAbstractions.workgroupsize(kernel) <: DynamicSize && workgroupsize === nothing
+         workgroupsize = (1024,)  # Vectorization, 4x unrolling, minimal grain size
+     end
+     iterspace, dynamic = partition(kernel, ndrange, workgroupsize)
+     # partition checked that the ndranges agreed
+     if KernelAbstractions.ndrange(kernel) <: StaticSize
+         ndrange = nothing
+     end
+
+     return ndrange, workgroupsize, iterspace, dynamic
+ end
+
+ KernelAbstractions.isgpu(b::JLBackend) = false
+
+ function convert_to_cpu(obj::Kernel{JLBackend, W, N, F}) where {W, N, F}
+     return Kernel{typeof(KernelAbstractions.CPU(; static = obj.backend.static)), W, N, F}(KernelAbstractions.CPU(; static = obj.backend.static), obj.f)
+ end
+
+ function (obj::Kernel{JLBackend})(args...; ndrange=nothing, workgroupsize=nothing)
+     device_args = jlconvert.(args)
+     new_obj = convert_to_cpu(obj)
+     new_obj(device_args...; ndrange, workgroupsize)
+
+ end
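Taken together, these definitions let an ordinary KernelAbstractions kernel run against JLArrays; a minimal sketch (the `scale!` kernel is hypothetical):

    using KernelAbstractions

    @kernel function scale!(y, x, a)
        i = @index(Global)
        @inbounds y[i] = a * x[i]
    end

    x = jl(rand(Float32, 32))
    y = similar(x)
    backend = KernelAbstractions.get_backend(x)     # JLBackend()
    scale!(backend)(y, x, 2f0; ndrange=length(x))   # dispatched through convert_to_cpu above
    Array(y)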
+
end