Skip to content
This repository was archived by the owner on May 27, 2021. It is now read-only.

Commit e32ff30

Browse files
committed
Fix bump allocator bug
1 parent 29d30f9 commit e32ff30

File tree

1 file changed

+8
-14
lines changed

1 file changed

+8
-14
lines changed

src/device/runtime.jl

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ function gc_pool_alloc(sz::Csize_t)
160160
@cuprintf("ERROR: Out of dynamic GPU memory (trying to allocate %i bytes)\n", sz)
161161
throw(OutOfMemoryError())
162162
end
163-
return
163+
return unsafe_pointer_to_objref(ptr)
164164
end
165165

166166
compile(gc_pool_alloc, Any, (Csize_t,), T_prjlvalue)
@@ -257,7 +257,7 @@ end
257257
# Gets a pointer to a global with a particular name. If the global
258258
# does not exist yet, then it is declared in the global memory address
259259
# space.
260-
@generated function get_global_pointer(::Val{global_name}, ::Type{T})::Ptr{T} where {global_name, T}
260+
@generated function get_global_pointer(::Val{global_name}, ::Type{T})::CUDAnative.DevicePtr{T} where {global_name, T}
261261
T_global = convert(LLVMType, T)
262262
T_result = convert(LLVMType, Ptr{T})
263263

@@ -289,23 +289,17 @@ end
289289
end
290290

291291
# Call the function.
292-
call_function(llvm_f, Ptr{T})
293-
end
294-
295-
macro cuda_global_ptr(name, type)
296-
return :(convert(
297-
DevicePtr{T},
298-
get_global_pointer(
299-
$(Val(Symbol(name))),
300-
$(esc(type)))))
292+
quote
293+
CUDAnative.DevicePtr{T, CUDAnative.AS.Generic}(convert(Csize_t, $(call_function(llvm_f, Ptr{T}))))
294+
end
301295
end
302296

303297
# Allocates `bytesize` bytes of storage by bumping the global bump
304298
# allocator pointer.
305299
function bump_alloc(bytesize::Csize_t)::Ptr{UInt8}
306-
ptr = @cuda_global_ptr("bump_alloc_ptr", Csize_t)
307-
chunk_address = atomic_add!(ptr, bytesize)
308-
end_ptr = unsafe_load(@cuda_global_ptr("bump_alloc_end", Csize_t))
300+
ptr = get_global_pointer(Val(:bump_alloc_ptr), Csize_t)
301+
chunk_address = CUDAnative.atomic_add!(ptr, bytesize)
302+
end_ptr = unsafe_load(get_global_pointer(Val(:bump_alloc_end), Csize_t))
309303
if chunk_address < end_ptr
310304
return convert(Ptr{UInt8}, chunk_address)
311305
else

0 commit comments

Comments
 (0)