Skip to content

Commit 00e703d

Browse files
Add GPU support for Complex atomics and refactor to shared utilities
- Extract shared Complex utilities to src/complex.jl - Add ComplexF32/ComplexF64 support to CUDA extension (via CAS loops) - Add ComplexF32 support to Metal extension (ComplexF64 not supported due to 128-bit limit) - Refactor CPU implementation to use shared utilities - Eliminates code duplication across backends
1 parent 9abbe99 commit 00e703d

File tree

5 files changed

+139
-73
lines changed

5 files changed

+139
-73
lines changed

ext/AtomixCUDAExt.jl

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
module AtomixCUDAExt
33

44
using Atomix: Atomix, IndexableRef
5+
using Atomix.Internal: _int_type_for_complex, _to_int, _from_int, _with_int_repr, _generic_cas_loop_modify!
56
using CUDA: CUDA, CuDeviceArray
67

78
const CuIndexableRef{Indexable<:CuDeviceArray} = IndexableRef{Indexable}
@@ -24,35 +25,57 @@ end
2425
ptr = Atomix.pointer(ref)
2526
expected = convert(eltype(ref), expected)
2627
desired = convert(eltype(ref), desired)
27-
begin
28-
old = CUDA.atomic_cas!(ptr, expected, desired)
29-
end
30-
return (; old = old, success = old === expected)
28+
_cuda_cas!(ptr, expected, desired)
29+
end
30+
31+
# CUDA CAS - with Complex support via integer reinterpretation
32+
@inline function _cuda_cas!(ptr::Ptr{T}, expected::T, desired::T) where {T}
33+
result = _with_int_repr(CUDA.atomic_cas!, ptr, expected, desired)
34+
return _from_int(T, (; old = result, success = result === expected))
35+
end
36+
37+
# CUDA load via CAS (for Complex support)
38+
@inline function _cuda_load(ptr::Ptr{T}) where {T}
39+
_with_int_repr(_cuda_load_impl, ptr)
40+
end
41+
42+
@inline function _cuda_load_impl(ptr::Ptr{T}) where {T}
43+
# Load via CAS with same value
44+
old = unsafe_load(ptr)
45+
CUDA.atomic_cas!(ptr, old, old)
3146
end
3247

3348
@inline function Atomix.modify!(ref::CuIndexableRef, op::OP, x, order) where {OP}
3449
x = convert(eltype(ref), x)
3550
ptr = Atomix.pointer(ref)
36-
begin
37-
old = if op === (+)
38-
CUDA.atomic_add!(ptr, x)
39-
elseif op === (-)
40-
CUDA.atomic_sub!(ptr, x)
41-
elseif op === (&)
42-
CUDA.atomic_and!(ptr, x)
43-
elseif op === (|)
44-
CUDA.atomic_or!(ptr, x)
45-
elseif op === xor
46-
CUDA.atomic_xor!(ptr, x)
47-
elseif op === min
48-
CUDA.atomic_min!(ptr, x)
49-
elseif op === max
50-
CUDA.atomic_max!(ptr, x)
51-
else
52-
error("not implemented")
53-
end
51+
_cuda_modify!(ptr, op, x)
52+
end
53+
54+
# CUDA modify! - native operations for non-Complex types
55+
@inline function _cuda_modify!(ptr::Ptr{T}, op::OP, x::T) where {T,OP}
56+
old = if op === (+)
57+
CUDA.atomic_add!(ptr, x)
58+
elseif op === (-)
59+
CUDA.atomic_sub!(ptr, x)
60+
elseif op === (&)
61+
CUDA.atomic_and!(ptr, x)
62+
elseif op === (|)
63+
CUDA.atomic_or!(ptr, x)
64+
elseif op === xor
65+
CUDA.atomic_xor!(ptr, x)
66+
elseif op === min
67+
CUDA.atomic_min!(ptr, x)
68+
elseif op === max
69+
CUDA.atomic_max!(ptr, x)
70+
else
71+
error("not implemented")
5472
end
5573
return old => op(old, x)
5674
end
5775

76+
# CUDA modify! - CAS loop for Complex types
77+
@inline function _cuda_modify!(ptr::Ptr{Complex{T}}, op::OP, x::Complex{T}) where {T,OP}
78+
_generic_cas_loop_modify!(_cuda_cas!, _cuda_load, ptr, op, x, nothing)
79+
end
80+
5881
end # module AtomixCUDAExt

ext/AtomixMetalExt.jl

Lines changed: 57 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
module AtomixMetalExt
33

44
using Atomix: Atomix, IndexableRef
5+
using Atomix.Internal: _int_type_for_complex, _to_int, _from_int, _with_int_repr, _generic_cas_loop_modify!
56
using Metal: Metal, MtlDeviceArray
67

78
const MtlIndexableRef{Indexable<:MtlDeviceArray} = IndexableRef{Indexable}
@@ -24,12 +25,39 @@ end
2425
ptr = Atomix.pointer(ref)
2526
expected = convert(eltype(ref), expected)
2627
desired = convert(eltype(ref), desired)
27-
begin
28-
old = Metal.atomic_compare_exchange_weak_explicit(ptr, expected, desired)
29-
end
28+
_metal_cas!(ptr, expected, desired)
29+
end
30+
31+
# Metal CAS - with ComplexF32 support (ComplexF64 not supported due to 128-bit limitation)
32+
@inline function _metal_cas!(ptr::Ptr{ComplexF32}, expected::ComplexF32, desired::ComplexF32)
33+
IntType = UInt64
34+
int_ptr = reinterpret(Ptr{IntType}, ptr)
35+
int_expected = reinterpret(IntType, expected)
36+
int_desired = reinterpret(IntType, desired)
37+
old = Metal.atomic_compare_exchange_weak_explicit(int_ptr, int_expected, int_desired)
38+
old_complex = reinterpret(ComplexF32, old)
39+
return (; old = old_complex, success = old_complex === expected)
40+
end
41+
42+
@inline function _metal_cas!(ptr::Ptr{T}, expected::T, desired::T) where {T}
43+
old = Metal.atomic_compare_exchange_weak_explicit(ptr, expected, desired)
3044
return (; old = old, success = old === expected)
3145
end
3246

47+
# Metal load via CAS (for Complex support)
48+
@inline function _metal_load(ptr::Ptr{ComplexF32})
49+
IntType = UInt64
50+
int_ptr = reinterpret(Ptr{IntType}, ptr)
51+
old = unsafe_load(int_ptr)
52+
result = Metal.atomic_compare_exchange_weak_explicit(int_ptr, old, old)
53+
return reinterpret(ComplexF32, result)
54+
end
55+
56+
@inline function _metal_load(ptr::Ptr{T}) where {T}
57+
old = unsafe_load(ptr)
58+
Metal.atomic_compare_exchange_weak_explicit(ptr, old, old)
59+
end
60+
3361

3462
# CAS is needed for FP ops on ThreadGroup memory
3563
@inline function Atomix.modify!(ref::IndexableRef{<:MtlDeviceArray{<:AbstractFloat, <:Any, Metal.AS.ThreadGroup}} , op::OP, x, order) where {OP}
@@ -42,26 +70,34 @@ end
4270
@inline function Atomix.modify!(ref::MtlIndexableRef, op::OP, x, order) where {OP}
4371
x = convert(eltype(ref), x)
4472
ptr = Atomix.pointer(ref)
45-
begin
46-
old = if op === (+)
47-
Metal.atomic_fetch_add_explicit(ptr, x)
48-
elseif op === (-)
49-
Metal.atomic_fetch_sub_explicit(ptr, x)
50-
elseif op === (&)
51-
Metal.atomic_fetch_and_explicit(ptr, x)
52-
elseif op === (|)
53-
Metal.atomic_fetch_or_explicit(ptr, x)
54-
elseif op === xor
55-
Metal.atomic_fetch_xor_explicit(ptr, x)
56-
elseif op === min
57-
Metal.atomic_fetch_min_explicit(ptr, x)
58-
elseif op === max
59-
Metal.atomic_fetch_max_explicit(ptr, x)
60-
else
61-
error("not implemented")
62-
end
73+
_metal_modify!(ptr, op, x)
74+
end
75+
76+
# Metal modify! - native operations for non-Complex types
77+
@inline function _metal_modify!(ptr::Ptr{T}, op::OP, x::T) where {T,OP}
78+
old = if op === (+)
79+
Metal.atomic_fetch_add_explicit(ptr, x)
80+
elseif op === (-)
81+
Metal.atomic_fetch_sub_explicit(ptr, x)
82+
elseif op === (&)
83+
Metal.atomic_fetch_and_explicit(ptr, x)
84+
elseif op === (|)
85+
Metal.atomic_fetch_or_explicit(ptr, x)
86+
elseif op === xor
87+
Metal.atomic_fetch_xor_explicit(ptr, x)
88+
elseif op === min
89+
Metal.atomic_fetch_min_explicit(ptr, x)
90+
elseif op === max
91+
Metal.atomic_fetch_max_explicit(ptr, x)
92+
else
93+
error("not implemented")
6394
end
6495
return old => op(old, x)
6596
end
6697

98+
# Metal modify! - CAS loop for ComplexF32 only (ComplexF64 not supported)
99+
@inline function _metal_modify!(ptr::Ptr{ComplexF32}, op::OP, x::ComplexF32) where {OP}
100+
_generic_cas_loop_modify!(_metal_cas!, _metal_load, ptr, op, x, nothing)
101+
end
102+
67103
end # module AtomixMetalExt

src/Atomix.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ using UnsafeAtomics:
3232

3333
include("utils.jl")
3434
include("references.jl")
35+
include("complex.jl")
3536
include("generic.jl")
3637
include("core.jl")
3738
include("sugar.jl")

src/complex.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Shared utilities for Complex number atomic operations across backends
2+
3+
# Integer type mapping for Complex types
4+
_int_type_for_complex(::Type{Float32}) = UInt64
5+
_int_type_for_complex(::Type{Float64}) = UInt128
6+
7+
# Convert to/from integer representation
8+
_to_int(::Type{I}, x::Complex{T}) where {I,T} = reinterpret(I, x)
9+
_to_int(::Type{I}, x) where {I} = x
10+
11+
_from_int(::Type{Complex{T}}, x::Integer) where {T} = reinterpret(Complex{T}, x)
12+
_from_int(::Type{Complex{T}}, result::NamedTuple) where {T} =
13+
(old = reinterpret(Complex{T}, result.old), success = result.success)
14+
_from_int(::Type{T}, x) where {T} = x
15+
16+
# Helper: apply atomic operation with integer reinterpretation for Complex types
17+
function _with_int_repr(f, ptr::Ptr{Complex{T}}, args...) where {T}
18+
IntType = _int_type_for_complex(T)
19+
int_ptr = reinterpret(Ptr{IntType}, ptr)
20+
result = f(int_ptr, _to_int.(IntType, args)...)
21+
return _from_int(Complex{T}, result)
22+
end
23+
24+
_with_int_repr(f, ptr::Ptr{T}, args...) where {T} = f(ptr, args...)
25+
26+
# Generic CAS loop for modify! operations (used when native atomics unavailable)
27+
function _generic_cas_loop_modify!(cas_fn, load_fn, ptr::Ptr{T}, op::OP, x::T, ord) where {T,OP}
28+
old = load_fn(ptr, ord)
29+
while true
30+
new = op(old, x)
31+
result = cas_fn(ptr, old, new, ord, ord)
32+
result.success && return (old => new)
33+
old = result.old
34+
end
35+
end

src/core.jl

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -53,36 +53,7 @@ end
5353

5454
# CAS loop fallback for Complex types (no native atomic modify!)
5555
function _atomic_modify!(ptr::Ptr{Complex{T}}, op::OP, x::Complex{T}, ord) where {T,OP}
56-
old = _atomic_load(ptr, ord)
57-
while true
58-
new = op(old, x)
59-
result = _atomic_cas!(ptr, old, new, ord, ord)
60-
result.success && return (old => new)
61-
old = result.old
62-
end
63-
end
64-
65-
# Helper: apply atomic operation with integer reinterpretation for Complex types
66-
function _with_int_repr(f, ptr::Ptr{Complex{T}}, args...) where {T}
67-
IntType = _int_type_for_complex(T)
68-
int_ptr = reinterpret(Ptr{IntType}, ptr)
69-
result = f(int_ptr, _to_int.(IntType, args)...)
70-
return _from_int(Complex{T}, result)
56+
_generic_cas_loop_modify!(_atomic_cas!, _atomic_load, ptr, op, x, ord)
7157
end
7258

73-
_with_int_repr(f, ptr::Ptr{T}, args...) where {T} = f(ptr, args...)
74-
75-
# Integer type mapping for Complex types
76-
_int_type_for_complex(::Type{Float32}) = UInt64
77-
_int_type_for_complex(::Type{Float64}) = UInt128
78-
79-
# Convert to/from integer representation
80-
_to_int(::Type{I}, x::Complex{T}) where {I,T} = reinterpret(I, x)
81-
_to_int(::Type{I}, x) where {I} = x
82-
83-
_from_int(::Type{Complex{T}}, x::Integer) where {T} = reinterpret(Complex{T}, x)
84-
_from_int(::Type{Complex{T}}, result::NamedTuple) where {T} =
85-
(old = reinterpret(Complex{T}, result.old), success = result.success)
86-
_from_int(::Type{T}, x) where {T} = x
87-
8859
Atomix.asstorable(ref, v) = convert(eltype(ref), v)

0 commit comments

Comments
 (0)