33using Printf
44using Logging
55using TimerOutputs
6+ using DataStructures
67
78include (" pool/utils.jl" )
89using . PoolUtils
6465const usage_limit = PerDevice {Int} () do dev
6566 if haskey (ENV , " JULIA_CUDA_MEMORY_LIMIT" )
6667 parse (Int, ENV [" JULIA_CUDA_MEMORY_LIMIT" ])
67- elseif haskey (ENV , " CUARRAYS_MEMORY_LIMIT" )
68- Base. depwarn (" The CUARRAYS_MEMORY_LIMIT environment flag is deprecated, please use JULIA_CUDA_MEMORY_LIMIT instead." , :__init_pool__ )
69- parse (Int, ENV [" CUARRAYS_MEMORY_LIMIT" ])
7068 else
7169 typemax (Int)
7270 end
@@ -116,7 +114,8 @@ function hard_limit(dev::CuDevice)
116114 usage_limit[dev]
117115end
118116
119- function actual_alloc (dev:: CuDevice , bytes:: Integer , last_resort:: Bool = false )
117+ function actual_alloc (dev:: CuDevice , bytes:: Integer , last_resort:: Bool = false ;
118+ stream_ordered:: Bool = false )
120119 buf = @device! dev begin
121120 # check the memory allocation limit
122121 if usage[dev][] + bytes > (last_resort ? hard_limit (dev) : soft_limit (dev))
@@ -127,7 +126,7 @@ function actual_alloc(dev::CuDevice, bytes::Integer, last_resort::Bool=false)
127126 try
128127 time = Base. @elapsed begin
129128 @timeit_debug alloc_to " alloc" begin
130- buf = Mem. alloc (Mem. Device, bytes; async= true )
129+ buf = Mem. alloc (Mem. Device, bytes; async= true , stream_ordered )
131130 end
132131 end
133132
@@ -146,7 +145,7 @@ function actual_alloc(dev::CuDevice, bytes::Integer, last_resort::Bool=false)
146145 return Block (buf, bytes; state= AVAILABLE)
147146end
148147
149- function actual_free (dev:: CuDevice , block:: Block )
148+ function actual_free (dev:: CuDevice , block:: Block ; stream_ordered :: Bool = false )
150149 @assert iswhole (block) " Cannot free $block : block is not whole"
151150 @assert block. off == 0
152151 @assert block. state == AVAILABLE " Cannot free $block : block is not available"
@@ -155,7 +154,7 @@ function actual_free(dev::CuDevice, block::Block)
155154 # free the memory
156155 @timeit_debug alloc_to " free" begin
157156 time = Base. @elapsed begin
158- Mem. free (block. buf; async= true )
157+ Mem. free (block. buf; async= true , stream_ordered )
159158 end
160159 block. state = INVALID
161160
@@ -181,41 +180,42 @@ Show the timings of the currently active memory pool. Assumes
181180pool_timings () = (show (PoolUtils. to; allocations= false , sortby= :name ); println ())
182181
183182# pool API:
184- # - init()
185- # - alloc(::CuDevice , sz)::Block
186- # - free(::CuDevice , ::Block)
187- # - reclaim(::CuDevice , nb::Int=typemax(Int))::Int
188- # - cached_memory()
183+ # - constructor taking a CuDevice
184+ # - alloc(::AbstractPool , sz)::Block
185+ # - free(::AbstractPool , ::Block)
186+ # - reclaim(::AbstractPool , nb::Int=typemax(Int))::Int
187+ # - cached_memory(::AbstractPool )
189188
190189module Pool
191190@enum MemoryPool None Simple Binned Split
192191end
193- const active_pool = Ref {Pool.MemoryPool} ()
194- const async_alloc = Ref {Bool} ()
195-
196- macro pooled (ex)
197- @assert Meta. isexpr (ex, :call )
198- f, args... = ex. args
199- quote
200- if active_pool[] == Pool. None
201- NoPool.$ (f)($ (map (esc, args)... ))
202- elseif active_pool[] == Pool. Simple
203- SimplePool.$ (f)($ (map (esc, args)... ))
204- elseif active_pool[] == Pool. Binned
205- BinnedPool.$ (f)($ (map (esc, args)... ))
206- elseif active_pool[] == Pool. Split
207- SplitPool.$ (f)($ (map (esc, args)... ))
208- else
209- error (" unreachable" )
210- end
211- end
212- end
213192
193+ abstract type AbstractPool end
214194include (" pool/none.jl" )
215195include (" pool/simple.jl" )
216196include (" pool/binned.jl" )
217197include (" pool/split.jl" )
218198
199+ const pools = PerDevice {AbstractPool} (dev-> begin
200+ default_pool = version () >= v " 11.2" ? " cuda" : " binned"
201+ pool_name = get (ENV , " JULIA_CUDA_MEMORY_POOL" , default_pool)
202+ pool = if pool_name == " none"
203+ NoPool (; dev, stream_ordered= false )
204+ elseif pool_name == " simple"
205+ SimplePool (; dev, stream_ordered= false )
206+ elseif pool_name == " binned"
207+ BinnedPool (; dev, stream_ordered= false )
208+ elseif pool_name == " split"
209+ SplitPool (; dev, stream_ordered= false )
210+ elseif pool_name == " cuda"
211+ @assert version () >= v " 11.2" " The CUDA memory pool is only supported on CUDA 11.2+"
212+ NoPool (; dev, stream_ordered= true )
213+ else
214+ error (" Invalid memory pool '$pool_name '" )
215+ end
216+ pool
217+ end )
218+
219219
220220# # interface
221221
@@ -263,11 +263,11 @@ a [`OutOfGPUMemoryError`](@ref) if the allocation request cannot be satisfied.
263263 sz == 0 && return CU_NULL
264264
265265 dev = device ()
266+ pool = pools[dev]
266267
267268 time = Base. @elapsed begin
268- @pool_timeit " pooled alloc" block = @pooled alloc (dev , sz)
269+ @pool_timeit " pooled alloc" block = alloc (pool , sz):: Union{Nothing,Block}
269270 end
270- block:: Union{Nothing,Block}
271271 block === nothing && throw (OutOfGPUMemoryError (sz))
272272
273273 # record the memory block
@@ -328,6 +328,7 @@ Releases a buffer pointed to by `ptr` to the memory pool.
328328 ptr == CU_NULL && return
329329
330330 dev = device ()
331+ pool = pools[dev]
331332 last_use[dev] = time ()
332333
333334 if MEMDEBUG && ptr == CuPtr {Cvoid} (0xbbbbbbbbbbbbbbbb )
@@ -359,7 +360,7 @@ Releases a buffer pointed to by `ptr` to the memory pool.
359360 end
360361
361362 time = Base. @elapsed begin
362- @pool_timeit " pooled free" @pooled free (dev , block)
363+ @pool_timeit " pooled free" free (pool , block)
363364 end
364365
365366 alloc_stats. pool_time += time
@@ -382,7 +383,8 @@ actually reclaimed.
382383"""
383384function reclaim (sz:: Int = typemax (Int))
384385 dev = device ()
385- @pooled reclaim (dev, sz)
386+ pool = pools[dev]
387+ reclaim (pool, sz)
386388end
387389
388390"""
@@ -403,6 +405,9 @@ macro retry_reclaim(isfailed, ex)
403405 ret = $ (esc (ex))
404406 $ (esc (isfailed))(ret) || break
405407
408+ dev = device ()
409+ pool = pools[dev]
410+
406411 # incrementally more costly reclaim of cached memory
407412 if phase == 1
408413 reclaim ()
@@ -412,11 +417,10 @@ macro retry_reclaim(isfailed, ex)
412417 elseif phase == 3
413418 GC. gc (true )
414419 reclaim ()
415- elseif phase == 4 && async_alloc[]
420+ elseif phase == 4 && pool . stream_ordered
416421 # this phase is unique to retry_reclaim, as regular allocations come from the pool
417422 # so are assumed to never need to trim its contents.
418- pool = memory_pool (device ())
419- trim (pool)
423+ trim (memory_pool (device ()))
420424 end
421425 end
422426 ret
@@ -445,7 +449,8 @@ function pool_cleanup()
445449
446450 if t1- t0 > 300
447451 # the pool hasn't been used for a while, so reclaim unused buffers
448- @pooled reclaim (dev)
452+ pool = pools[dev]
453+ reclaim (pool)
449454 end
450455 end
451456
@@ -561,7 +566,10 @@ macro timed(ex)
561566 end
562567end
563568
564- cached_memory () = @pooled cached_memory ()
569+ function cached_memory (dev:: CuDevice = device ())
570+ pool = pools[dev]
571+ cached_memory (pool)
572+ end
565573
566574"""
567575 memory_status([io=stdout])
@@ -584,10 +592,11 @@ function memory_status(io::IO=stdout)
584592 end
585593 println (io)
586594
587- alloc_used_bytes = used_memory ()
588- alloc_cached_bytes = cached_memory ()
595+ pool = pools[dev]
596+ alloc_used_bytes = used_memory (dev)
597+ alloc_cached_bytes = cached_memory (pool)
589598 alloc_total_bytes = alloc_used_bytes + alloc_cached_bytes
590- @printf (io, " Memory pool '%s' usage: %s (%s allocated, %s cached)\n " , string (active_pool[] ),
599+ @printf (io, " Memory pool '%s' usage: %s (%s allocated, %s cached)\n " , string (pool ),
591600 Base. format_bytes (alloc_total_bytes), Base. format_bytes (alloc_used_bytes),
592601 Base. format_bytes (alloc_cached_bytes))
593602
@@ -627,24 +636,8 @@ function __init_pool__()
627636 initialize! (allocated, ndevices ())
628637 initialize! (requested, ndevices ())
629638
630- # memory pool configuration
631- default_pool = version () >= v " 11.2" ? " cuda" : " binned"
632- pool_name = get (ENV , " JULIA_CUDA_MEMORY_POOL" , default_pool)
633- active_pool[], async_alloc[] = if pool_name == " none"
634- Pool. None, false
635- elseif pool_name == " simple"
636- Pool. Simple, false
637- elseif pool_name == " binned"
638- Pool. Binned, false
639- elseif pool_name == " split"
640- Pool. Split, false
641- elseif pool_name == " cuda"
642- @assert version () >= v " 11.2" " The CUDA memory pool is only supported on CUDA 11.2+"
643- Pool. None, true
644- else
645- error (" Invalid memory pool '$pool_name '" )
646- end
647- @pooled init ()
639+ # memory pools
640+ initialize! (pools, ndevices ())
648641
649642 TimerOutputs. reset_timer! (alloc_to)
650643 TimerOutputs. reset_timer! (PoolUtils. to)
@@ -660,6 +653,6 @@ function __init_pool__()
660653 end
661654
662655 if isinteractive ()
663- @async @pooled pool_cleanup ()
656+ @async pool_cleanup ()
664657 end
665658end
0 commit comments