@@ -65,6 +65,14 @@ _contiguous_axis(::Any, ::Nothing) = nothing
6565 Expr (:call , Expr (:curly , :Contiguous , new_contig))
6666end
6767
68+ # contiguous_if_one(::Contiguous{1}) = Contiguous{1}()
69+ # contiguous_if_one(::Any) = Contiguous{-1}()
70+ function contiguous_axis (:: Type{R} ) where {T, N, S, A <: Array{S} , R <: Base.ReinterpretArray{T, N, S, A} }
71+ isbitstype (S) ? Contiguous {1} () : nothing
72+ # contiguous_if_one(contiguous_axis(parent_type(R)))
73+ end
74+
75+
6876"""
6977contiguous_axis_indicator(::Type{T}) -> Tuple{Vararg{<:Val}}
7078
@@ -108,6 +116,8 @@ _contiguous_batch_size(::Any, ::Any, ::Any) = nothing
108116 end
109117end
110118
119+ contiguous_batch_size (:: Type{R} ) where {T, N, S, A <: Array{S} , R <: Base.ReinterpretArray{T, N, S, A} } = ContiguousBatch {0} ()
120+
111121struct StrideRank{R} end
112122Base. @pure StrideRank (R:: NTuple{<:Any,Int} ) = StrideRank {R} ()
113123_get (:: StrideRank{R} ) where {R} = R
@@ -164,6 +174,7 @@ _stride_rank(::Any, ::Any) = nothing
164174 Expr (:call , Expr (:curly , :StrideRank , ranktup))
165175end
166176stride_rank (x, i) = stride_rank (x)[i]
177+ stride_rank (:: Type{R} ) where {T, N, S, A <: Array{S} , R <: Base.ReinterpretArray{T, N, S, A} } = StrideRank {ntuple(identity, Val{N}())} ()
167178
168179"""
169180is_column_major(A) -> Val{true/false}()
@@ -248,6 +259,16 @@ julia> A = rand(3,4);
248259
249260julia> ArrayInterface.strides(A)
250261(StaticInt{1}(), 3)
262+
263+ Additionally, the behavior differs from `Base.strides` for adjoint vectors:
264+
265+ julia> x = rand(5);
266+
267+ julia> ArrayInterface.strides(x')
268+ (StaticInt{1}(), StaticInt{1}())
269+
270+ This is to support the pattern of using just the first stride for linear indexing, `x[i]`,
271+ while still producing correct behavior when using valid cartesian indices, such as `x[1,i]`.
251272```
252273"""
253274strides (A) = Base. strides (A)
@@ -264,6 +285,16 @@ offsets(::Any) = (StaticInt{1}(),) # Assume arbitrary Julia data structures use
264285@inline strides (A:: Vector{<:Any} ) = (StaticInt (1 ),)
265286@inline strides (A:: Array{<:Any,N} ) where {N} = (StaticInt (1 ), Base. tail (Base. strides (A))... )
266287@inline strides (A:: AbstractArray ) = _strides (A, Base. strides (A), contiguous_axis (A))
288+
289+ @inline function strides (x:: LinearAlgebra.Adjoint{T,V} ) where {T, V <: AbstractVector{T} }
290+ strd = stride (parent (x), One ())
291+ (strd, strd)
292+ end
293+ @inline function strides (x:: LinearAlgebra.Transpose{T,V} ) where {T, V <: AbstractVector{T} }
294+ strd = stride (parent (x), One ())
295+ (strd, strd)
296+ end
297+
267298@generated function _strides (A:: AbstractArray{T,N} , s:: NTuple{N} , :: Contiguous{C} ) where {T,N,C}
268299 if C ≤ 0 || C > N
269300 return Expr (:block , Expr (:meta ,:inline ), :s )
@@ -282,6 +313,22 @@ offsets(::Any) = (StaticInt{1}(),) # Assume arbitrary Julia data structures use
282313 end
283314end
284315
316+ if VERSION ≥ v " 1.6.0-DEV.1581"
317+ @generated function _strides (_:: Base.ReinterpretArray{T, N, S, A, true} , s:: NTuple{N} , :: Contiguous{1} ) where {T, N, S, D, A <: Array{S,D} }
318+ stup = Expr (:tuple , :(One ()))
319+ if D < N
320+ push! (stup. args, Expr (:call , Expr (:curly , :StaticInt , sizeof (S) ÷ sizeof (T))))
321+ end
322+ for n ∈ 2 + (D < N): N
323+ push! (stup. args, Expr (:ref , :s , n))
324+ end
325+ quote
326+ $ (Expr (:meta ,:inline ))
327+ @inbounds $ stup
328+ end
329+ end
330+ end
331+
285332@inline function offsets (x, i)
286333 inds = indices (x, i)
287334 start = known_first (inds)
304351@inline strides (B:: PermutedDimsArray{T,N,I1,I2,A} ) where {T,N,I1,I2,A<: AbstractArray{T,N} } = permute (strides (parent (B)), Val {I1} ())
305352@inline stride (A:: AbstractArray , :: StaticInt{N} ) where {N} = strides (A)[N]
306353@inline stride (A:: AbstractArray , :: Val{N} ) where {N} = strides (A)[N]
307- stride (A, i) = Base. stride (A, i)
354+ stride (A, i) = Base. stride (A, i) # for type stability
308355
309356size (B:: S ) where {N,NP,T,A<: AbstractArray{T,NP} ,I,S <: SubArray{T,N,A,I} } = _size (size (parent (B)), B. indices, map (static_length, B. indices))
310357strides (B:: S ) where {N,NP,T,A<: AbstractArray{T,NP} ,I,S <: SubArray{T,N,A,I} } = _strides (strides (parent (B)), B. indices)
324371@generated function _strides (A:: Tuple{Vararg{Any,N}} , inds:: I ) where {N, I<: Tuple }
325372 t = Expr (:tuple )
326373 for n in 1 : N
327- if I. parameters[n] <: AbstractRange
374+ if I. parameters[n] <: AbstractUnitRange
328375 push! (t. args, Expr (:ref , :A , n))
376+ elseif I. parameters[n] <: AbstractRange
377+ push! (t. args, Expr (:call , :(* ), Expr (:ref , :A , n), Expr (:call , :static_step , Expr (:ref , :inds , n))))
329378 elseif ! (I. parameters[n] <: Integer )
330379 return nothing
331380 end
0 commit comments