Skip to content

Commit 0f299dc

Browse files
committed
Fixes
1 parent 3a74c59 commit 0f299dc

File tree

5 files changed

+35
-28
lines changed

5 files changed

+35
-28
lines changed

lib/mkl/array.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ const oneAbstractSparseVector{Tv, Ti} = oneAbstractSparseArray{Tv, Ti, 1}
55
const oneAbstractSparseMatrix{Tv, Ti} = oneAbstractSparseArray{Tv, Ti, 2}
66

77
mutable struct oneSparseMatrixCSR{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
8-
handle::matrix_handle_t
8+
handle::Union{Nothing, matrix_handle_t}
99
rowPtr::oneVector{Ti}
1010
colVal::oneVector{Ti}
1111
nzVal::oneVector{Tv}
@@ -14,7 +14,7 @@ mutable struct oneSparseMatrixCSR{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
1414
end
1515

1616
mutable struct oneSparseMatrixCSC{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
17-
handle::matrix_handle_t
17+
handle::Union{Nothing, matrix_handle_t}
1818
colPtr::oneVector{Ti}
1919
rowVal::oneVector{Ti}
2020
nzVal::oneVector{Tv}
@@ -23,7 +23,7 @@ mutable struct oneSparseMatrixCSC{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
2323
end
2424

2525
mutable struct oneSparseMatrixCOO{Tv, Ti} <: oneAbstractSparseMatrix{Tv, Ti}
26-
handle::matrix_handle_t
26+
handle::Union{Nothing, matrix_handle_t}
2727
rowInd::oneVector{Ti}
2828
colInd::oneVector{Ti}
2929
nzVal::oneVector{Tv}

lib/mkl/wrappers_sparse.jl

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
function sparse_release_matrix_handle(A::oneAbstractSparseMatrix)
2-
queue = global_queue(context(A.nzVal), device(A.nzVal))
3-
m, n = size(A)
4-
return if m != 0 && n != 0
2+
return if A.handle !== nothing
3+
queue = global_queue(context(A.nzVal), device(A.nzVal))
4+
oneL0.synchronize(queue)
55
handle_ptr = Ref{matrix_handle_t}(A.handle)
66
onemklXsparse_release_matrix_handle(sycl_queue(queue), handle_ptr)
77
end
@@ -67,8 +67,6 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
6767
if m != 0 && n != 0
6868
$fname(sycl_queue(queue), handle_ptr[], m, n, 'O', rowPtr, colVal, nzVal)
6969
end
70-
dA = oneSparseMatrixCSR{$elty, $intty}(handle_ptr[], rowPtr, colVal, nzVal, (m, n), nnzA)
71-
finalizer(sparse_release_matrix_handle, dA)
7270
return dA
7371
end
7472

@@ -84,9 +82,11 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
8482
# Don't update handle if matrix is empty
8583
if m != 0 && n != 0
8684
$fname(sycl_queue(queue), handle_ptr[], n, m, 'O', colPtr, rowVal, nzVal) # CSC of A is CSR of Aᵀ
85+
dA = oneSparseMatrixCSC{$elty, $intty}(handle_ptr[], colPtr, rowVal, nzVal, dims, nnzA)
86+
finalizer(sparse_release_matrix_handle, dA)
87+
else
88+
dA = oneSparseMatrixCSC{$elty, $intty}(nothing, colPtr, rowVal, nzVal, dims, nnzA)
8789
end
88-
dA = oneSparseMatrixCSC{$elty, $intty}(handle_ptr[], colPtr, rowVal, nzVal, dims, nnzA)
89-
finalizer(sparse_release_matrix_handle, dA)
9090
return dA
9191
end
9292

@@ -100,7 +100,6 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
100100
end
101101

102102
function SparseMatrixCSC(A::oneSparseMatrixCSR{$elty, $intty})
103-
handle_ptr = Ref{matrix_handle_t}()
104103
At = SparseMatrixCSC(reverse(A.dims)..., Vector(A.rowPtr), Vector(A.colVal), Vector(A.nzVal))
105104
A_csc = SparseMatrixCSC(At |> transpose)
106105
return A_csc
@@ -115,7 +114,6 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
115114
end
116115

117116
function SparseMatrixCSC(A::oneSparseMatrixCSC{$elty, $intty})
118-
handle_ptr = Ref{matrix_handle_t}()
119117
A_csc = SparseMatrixCSC(A.dims..., Vector(A.colPtr), Vector(A.rowVal), Vector(A.nzVal))
120118
return A_csc
121119
end
@@ -141,14 +139,17 @@ for (fname, elty, intty) in ((:onemklSsparse_set_coo_data , :Float32 , :Int3
141139
nzVal = oneVector{$elty}(val)
142140
nnzA = length(val)
143141
queue = global_queue(context(nzVal), device(nzVal))
144-
$fname(sycl_queue(queue), handle_ptr[], m, n, nnzA, 'O', rowInd, colInd, nzVal)
145-
dA = oneSparseMatrixCOO{$elty, $intty}(handle_ptr[], rowInd, colInd, nzVal, (m,n), nnzA)
146-
finalizer(sparse_release_matrix_handle, dA)
142+
if m != 0 && n != 0
143+
$fname(sycl_queue(queue), handle_ptr[], m, n, nnzA, 'O', rowInd, colInd, nzVal)
144+
dA = oneSparseMatrixCOO{$elty, $intty}(handle_ptr[], rowInd, colInd, nzVal, (m,n), nnzA)
145+
finalizer(sparse_release_matrix_handle, dA)
146+
else
147+
dA = oneSparseMatrixCOO{$elty, $intty}(nothing, rowInd, colInd, nzVal, (m,n), nnzA)
148+
end
147149
return dA
148150
end
149151

150152
function SparseMatrixCSC(A::oneSparseMatrixCOO{$elty, $intty})
151-
handle_ptr = Ref{matrix_handle_t}()
152153
A = sparse(Vector(A.rowInd), Vector(A.colInd), Vector(A.nzVal), A.dims...)
153154
return A
154155
end

src/indexing.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,17 @@ function Base.findall(bools::oneArray{Bool})
2020
I = keytype(bools)
2121

2222
indices = cumsum(reshape(bools, prod(size(bools))))
23-
oneL0.synchronize()
2423

2524
n = isempty(indices) ? 0 : @allowscalar indices[end]
2625

2726
ys = oneArray{I}(undef, n)
2827

2928
if n > 0
30-
@oneapi items = length(bools) _ker!(ys, bools, indices)
29+
kernel = @oneapi launch=false _ker!(ys, bools, indices)
30+
group_size = launch_configuration(kernel)
31+
kernel(ys, bools, indices; items=group_size, groups=cld(length(bools), group_size))
3132
end
32-
oneL0.synchronize()
33-
unsafe_free!(indices)
33+
# unsafe_free!(indices)
3434

3535
return ys
3636
end

test/indexing.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,18 @@ using oneAPI
1717
data = oneArray(collect(1:6))
1818
mask = oneArray(Bool[true, false, true, false, false, true])
1919
@test Array(data[mask]) == collect(1:6)[findall(Bool[true, false, true, false, false, true])]
20+
21+
# Test with array larger than 1024 to trigger multiple groups
22+
large_size = 2048
23+
large_mask = oneArray(rand(Bool, large_size))
24+
large_result_gpu = Array(findall(large_mask))
25+
large_result_cpu = findall(Array(large_mask))
26+
@test large_result_gpu == large_result_cpu
27+
28+
# Test with even larger array to ensure robustness
29+
very_large_size = 5000
30+
very_large_mask = oneArray(fill(true, very_large_size)) # all true for predictable result
31+
very_large_result_gpu = Array(findall(very_large_mask))
32+
very_large_result_cpu = findall(fill(true, very_large_size))
33+
@test very_large_result_gpu == very_large_result_cpu
2034
end

test/onemkl.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1093,11 +1093,7 @@ end
10931093
C = oneSparseMatrixCSR(B.rowPtr, B.colVal, B.nzVal, size(B))
10941094
A3 = SparseMatrixCSC(C)
10951095
@test A == A3
1096-
<<<<<<< HEAD
10971096
D = oneSparseMatrixCSR(oneVector(S[]), oneVector(S[]), oneVector(T[]), (0, 0)) # empty matrix
1098-
=======
1099-
D = oneSparseMatrixCSR(S[], S[], T[], (0, 0)) # empty matrix
1100-
>>>>>>> 1a9e6ed (Format)
11011097
end
11021098
end
11031099

@@ -1112,11 +1108,7 @@ end
11121108
C = oneSparseMatrixCSC(A.colptr |> oneVector, A.rowval |> oneVector, A.nzval |> oneVector, size(A))
11131109
A3 = SparseMatrixCSC(C)
11141110
@test A == A3
1115-
<<<<<<< HEAD
11161111
D = oneSparseMatrixCSC(oneVector(S[]), oneVector(S[]), oneVector(T[]), (0, 0)) # empty matrix
1117-
=======
1118-
D = oneSparseMatrixCSC(S[], S[], T[], (0, 0)) # empty matrix
1119-
>>>>>>> 1a9e6ed (Format)
11201112
end
11211113
end
11221114

0 commit comments

Comments
 (0)