Skip to content
This repository was archived by the owner on May 27, 2021. It is now read-only.

Commit 3972441

Browse files
authored
Merge pull request #416 from JuliaGPU/tb/launch_config
Support dynamic launch configuration.
2 parents 5c9d422 + a98992c commit 3972441

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

src/execution.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ export @cuda, cudaconvert, cufunction, dynamic_cufunction, nearest_warpsize
1010
function split_kwargs(kwargs)
1111
macro_kws = [:dynamic]
1212
compiler_kws = [:minthreads, :maxthreads, :blocks_per_sm, :maxregs, :name]
13-
call_kws = [:cooperative, :blocks, :threads, :shmem, :stream]
13+
call_kws = [:cooperative, :blocks, :threads, :config, :shmem, :stream]
1414
macro_kwargs = []
1515
compiler_kwargs = []
1616
call_kwargs = []
@@ -226,6 +226,9 @@ The following keyword arguments are supported:
226226
- `threads` (defaults to 1)
227227
- `blocks` (defaults to 1)
228228
- `shmem` (defaults to 0)
229+
- `config`: callback function to dynamically compute the launch configuration.
230+
should accept a `HostKernel` and return a named tuple with any of the above as fields.
231+
this functionality is intended to be used in combination with the CUDA occupancy API.
229232
- `stream` (defaults to the default stream)
230233
"""
231234
AbstractKernel
@@ -269,8 +272,13 @@ end
269272

270273
@doc (@doc AbstractKernel) HostKernel
271274

272-
@inline cudacall(kernel::HostKernel, tt, args...; kwargs...) =
273-
CUDAdrv.cudacall(kernel.fun, tt, args...; kwargs...)
275+
@inline function cudacall(kernel::HostKernel, tt, args...; config=nothing, kwargs...)
276+
if config !== nothing
277+
CUDAdrv.cudacall(kernel.fun, tt, args...; kwargs..., config(kernel)...)
278+
else
279+
CUDAdrv.cudacall(kernel.fun, tt, args...; kwargs...)
280+
end
281+
end
274282

275283
"""
276284
version(k::HostKernel)

test/device/execution.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,24 @@ dummy() = return
2222
end
2323

2424

25+
@testset "launch configuration" begin
26+
@cuda dummy()
27+
28+
@cuda threads=1 dummy()
29+
@cuda threads=(1,1) dummy()
30+
@cuda threads=(1,1,1) dummy()
31+
32+
@cuda blocks=1 dummy()
33+
@cuda blocks=(1,1) dummy()
34+
@cuda blocks=(1,1,1) dummy()
35+
36+
@cuda config=(kernel)->() dummy()
37+
@cuda config=(kernel)->(threads=1,) dummy()
38+
@cuda config=(kernel)->(blocks=1,) dummy()
39+
@cuda config=(kernel)->(shmem=0,) dummy()
40+
end
41+
42+
2543
@testset "compilation params" begin
2644
@cuda dummy()
2745

0 commit comments

Comments
 (0)