@@ -58,27 +58,33 @@ equals_layout(::SubBasisLayouts, ::SubBasisLayouts, A::SubQuasiArray, B::SubQuas
5858equals_layout (:: MappedBasisLayouts , :: MappedBasisLayouts , A:: SubQuasiArray , B:: SubQuasiArray ) = parentindices (A) == parentindices (B) && demap (A) == demap (B)
5959equals_layout (:: AbstractWeightedBasisLayout , :: AbstractWeightedBasisLayout , A, B) = weight (A) == weight (B) && unweighted (A) == unweighted (B)
6060
61- @inline copy (L:: Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(+)}} ) = + (broadcast (\ ,Ref (L. A),arguments (L. B))... )
62- @inline copy (L:: Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(+)},<:Any,<:AbstractQuasiVector} ) =
63- transform_ldiv (L. A, L. B)
6461for op in (:+ , :- )
65- @eval @inline copy (L:: Ldiv{Lay,BroadcastLayout{typeof($op)},<:Any,<:AbstractQuasiVector} ) where Lay<: MappedBasisLayouts =
66- copy (Ldiv {Lay,LazyLayout} (L. A,L. B))
62+ @eval begin
63+ @inline copy (L:: Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof($op)}} ) = basis_broadcast_ldiv_size ($ op, size (L), L. A, L. B)
64+ @inline copy (L:: Ldiv{<:MappedBasisLayouts,BroadcastLayout{typeof($op)}} ) = copy (Ldiv {BasisLayout,BroadcastLayout{typeof($op)}} (L. A, L. B))
65+ basis_broadcast_ldiv_size (:: typeof ($ op), :: Tuple{Integer} , A, B) = transform_ldiv (A, B)
66+ end
6767end
6868
69- @inline function copy (L:: Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(-)}} )
70- a,b = arguments (L. B)
71- (L. A\ a)- (L. A\ b)
69+ basis_broadcast_ldiv_size (:: typeof (+ ), _, A, B) = + (broadcast (\ ,Ref (A),arguments (B))... )
70+
71+
72+
73+ @inline function basis_broadcast_ldiv_size (:: typeof (- ), _, A, B)
74+ a,b = arguments (B)
75+ (A\ a)- (A\ b)
7276end
7377
74- @inline copy (L:: Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(-)},<:Any,<:AbstractQuasiVector} ) =
75- transform_ldiv (L. A, L. B)
7678
79+ # TODO : remove as Not type stable
80+ simplifiable (L:: Ldiv{<:AbstractBasisLayout,<:AbstractBasisLayout} ) = Val (L. A == L. B)
7781@inline function copy (P:: Ldiv{<:AbstractBasisLayout,<:AbstractBasisLayout} )
7882 A, B = P. A, P. B
7983 A == B || throw (ArgumentError (" Override copy for $(typeof (A)) \\ $(typeof (B)) " ))
8084 SquareEye {eltype(eltype(P))} ((axes (A,2 ),)) # use double eltype for array-valued
8185end
86+
87+ simplifiable (L:: Ldiv{<:SubBasisLayouts,<:SubBasisLayouts} ) = Val (parent (L. A) == parent (L. B))
8288@inline function copy (P:: Ldiv{<:SubBasisLayouts,<:SubBasisLayouts} )
8389 A, B = P. A, P. B
8490 parent (A) == parent (B) ||
9197 demap (A)\ demap (B)
9298end
9399
94- function transform_ldiv_if_columns (P:: Ldiv{<:MappedBasisLayouts,<:Any,<:Any,<:AbstractQuasiVector} , :: OneTo )
95- A,B = P. A, P. B
96- demap (A) \ B[invmap (basismap (A))]
97- end
98-
99- function transform_ldiv_if_columns (P:: Ldiv{<:MappedBasisLayouts,<:Any,<:Any,<:AbstractQuasiMatrix} , :: OneTo )
100- A,B = P. A, P. B
101- demap (A) \ B[invmap (basismap (A)),:]
102- end
100+ copy (P:: Ldiv{<:MappedBasisLayouts} ) = mapped_ldiv_size (size (P), P. A, P. B)
101+ copy (P:: Ldiv{<:MappedBasisLayouts, <:AbstractLazyLayout} ) = mapped_ldiv_size (size (P), P. A, P. B)
102+ copy (P:: Ldiv{<:MappedBasisLayouts, <:AbstractBasisLayout} ) = mapped_ldiv_size (size (P), P. A, P. B)
103+ @inline copy (L:: Ldiv{<:MappedBasisLayouts,ApplyLayout{typeof(hcat)}} ) = mapped_ldiv_size (size (L), L. A, L. B)
104+ copy (P:: Ldiv{<:MappedBasisLayouts, ApplyLayout{typeof(*)}} ) = copy (Ldiv {BasisLayout,ApplyLayout{typeof(*)}} (P. A, P. B))
103105
104- copy (L:: Ldiv{<:MappedBasisLayouts,ApplyLayout{typeof(*)}} ) = copy (Ldiv {UnknownLayout,ApplyLayout{typeof(*)}} (L. A,L. B))
105- copy (L:: Ldiv{<:MappedBasisLayouts,ApplyLayout{typeof(*)},<:Any,<:AbstractQuasiVector} ) = transform_ldiv (L. A, L. B)
106+ mapped_ldiv_size (:: Tuple{Integer} , A, B) = demap (A) \ B[invmap (basismap (A))]
107+ mapped_ldiv_size (:: Tuple{Integer,Int} , A, B) = demap (A) \ B[invmap (basismap (A)),:]
108+ mapped_ldiv_size (:: Tuple{Integer,Any} , A, B) = copy (Ldiv {BasisLayout,typeof(MemoryLayout(B))} (A, B))
106109
107- @inline copy (L:: Ldiv{<:AbstractBasisLayout,<:SubBasisLayouts} ) = apply (\ , L. A, ApplyQuasiArray (L. B))
110+ # following allows us to use simplification
111+ @inline copy (L:: Ldiv{Lay,<:SubBasisLayouts} ) where Lay<: AbstractBasisLayout = copy (Ldiv {Lay,ApplyLayout{typeof(*)}} (L. A, L. B))
108112@inline function copy (L:: Ldiv{<:SubBasisLayouts,<:AbstractBasisLayout} )
109113 P = parent (L. A)
110114 kr, jr = parentindices (L. A)
@@ -146,11 +150,7 @@ _broadcast_mul_ldiv(::Tuple{ScalarLayout,AbstractBasisLayout}, A, B) =
146150_broadcast_mul_ldiv (_, A, B) = copy (Ldiv {typeof(MemoryLayout(A)),UnknownLayout} (A,B))
147151
148152copy (L:: Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(*)}} ) = _broadcast_mul_ldiv (map (MemoryLayout,arguments (L. B)), L. A, L. B)
149- copy (L:: Ldiv{<:AbstractBasisLayout,BroadcastLayout{typeof(*)},<:Any,<:AbstractQuasiVector} ) = _broadcast_mul_ldiv (map (MemoryLayout,arguments (L. B)), L. A, L. B)
150-
151- # ambiguity
152153copy (L:: Ldiv{<:MappedBasisLayouts,BroadcastLayout{typeof(*)}} ) = _broadcast_mul_ldiv (map (MemoryLayout,arguments (L. B)), L. A, L. B)
153- copy (L:: Ldiv{<:MappedBasisLayouts,BroadcastLayout{typeof(*)},<:Any,<:AbstractQuasiVector} ) = _broadcast_mul_ldiv (map (MemoryLayout,arguments (L. B)), L. A, L. B)
154154
155155
156156# expansion
257257plan_ldiv (A, B:: AbstractQuasiVector ) = factorize (A)
258258plan_ldiv (A, B:: AbstractQuasiMatrix ) = factorize (A, size (B,2 ))
259259
260- transform_ldiv ( A:: AbstractQuasiArray{T} , B:: AbstractQuasiArray{V} , _ ) where {T,V} = plan_ldiv (A, B) \ B
261- transform_ldiv (A, B) = transform_ldiv (A, B, size (A) )
260+ transform_ldiv_size (_, A:: AbstractQuasiArray{T} , B:: AbstractQuasiArray{V} ) where {T,V} = plan_ldiv (A, B) \ B
261+ transform_ldiv (A, B) = transform_ldiv_size ( size (A), A, B )
262262
263263
264264"""
@@ -291,28 +291,29 @@ in that basis.
291291"""
292292function expand (v)
293293 P = basis (v)
294- ApplyQuasiArray (* , P, P \ v)
294+ ApplyQuasiArray (* , P, tocoefficients ( P \ v) )
295295end
296296
297297
298298
299- copy (L:: Ldiv{<:AbstractBasisLayout} ) = transform_ldiv (L. A, L. B)
300- # TODO : redesign to use simplifiable(\, A, B)
301- copy (L:: Ldiv{<:AbstractBasisLayout,ApplyLayout{typeof(*)},<:Any,<:AbstractQuasiVector} ) = transform_ldiv (L. A, L. B)
302- copy (L:: Ldiv{<:AbstractBasisLayout,ApplyLayout{typeof(*)}} ) = copy (Ldiv {UnknownLayout,ApplyLayout{typeof(*)}} (L. A, L. B))
303- # A BroadcastLayout of unknown function is only knowable pointwise
304- transform_ldiv_if_columns (L, _) = ApplyQuasiArray (\ , L. A, L. B)
305- transform_ldiv_if_columns (L, :: OneTo ) = transform_ldiv (L. A,L. B)
306- transform_ldiv_if_columns (L) = transform_ldiv_if_columns (L, axes (L. B,2 ))
307- copy (L:: Ldiv{<:AbstractBasisLayout,<:BroadcastLayout} ) = transform_ldiv_if_columns (L)
308- # Inclusion are QuasiArrayLayout
309- copy (L:: Ldiv{<:AbstractBasisLayout,QuasiArrayLayout} ) = transform_ldiv (L. A, L. B)
310- # Otherwise keep lazy to support, e.g., U\D*T
311- copy (L:: Ldiv{<:AbstractBasisLayout,<:AbstractLazyLayout} ) = transform_ldiv_if_columns (L)
312- copy (L:: Ldiv{<:AbstractBasisLayout,ZerosLayout} ) = Zeros {eltype(L)} (axes (L)... )
313299
314- transform_ldiv_if_columns (L:: Ldiv{<:Any,<:ApplyLayout{typeof(hcat)}} , :: OneTo ) = transform_ldiv (L. A, L. B)
315- transform_ldiv_if_columns (L:: Ldiv{<:Any,<:ApplyLayout{typeof(hcat)}} , _) = hcat ((Ref (L. A) .\ arguments (hcat, L. B)). .. )
300+
301+ @inline copy (L:: Ldiv{<:AbstractBasisLayout} ) = basis_ldiv_size (size (L), L. A, L. B)
302+ @inline copy (L:: Ldiv{<:AbstractBasisLayout,<:AbstractLazyLayout} ) = basis_ldiv_size (size (L), L. A, L. B)
303+ @inline function copy (L:: Ldiv{<:AbstractBasisLayout,ApplyLayout{typeof(*)}} )
304+ simplifiable (\ , L. A, first (arguments (* , L. B))) isa Val{true } && return copy (Ldiv {UnknownLayout,ApplyLayout{typeof(*)}} (L. A, L. B))
305+ basis_ldiv_size (size (L), L. A, L. B)
306+ end
307+ @inline copy (L:: Ldiv{<:AbstractBasisLayout,ZerosLayout} ) = Zeros {eltype(L)} (axes (L)... )
308+
309+ @inline basis_ldiv_size (_, A, B) = copy (Ldiv {UnknownLayout,typeof(MemoryLayout(B))} (A, B))
310+ @inline basis_ldiv_size (:: Tuple{Integer} , A, B) = transform_ldiv (A, B)
311+ @inline basis_ldiv_size (:: Tuple{Integer,Int} , A, B) = transform_ldiv (A, B)
312+
313+ @inline copy (L:: Ldiv{<:AbstractBasisLayout,ApplyLayout{typeof(hcat)}} ) = basis_hcat_ldiv_size (size (L), L. A, L. B)
314+ @inline basis_hcat_ldiv_size (:: Tuple{Integer,Int} , A, B) = transform_ldiv (A, B)
315+ @inline basis_hcat_ldiv_size (_, A, B) = hcat ((Ref (A) .\ arguments (hcat, B)). .. )
316+
316317
317318"""
318319 WeightedFactorization(w, F)
@@ -334,7 +335,14 @@ _factorize(::WeightedBasisLayouts, wS, dims...; kws...) = WeightedFactorization(
334335# #
335336
336337struct ExpansionLayout{Lay} <: AbstractLazyLayout end
337- applylayout (:: Type{typeof(*)} , :: Lay , :: Union{PaddedLayout,AbstractStridedLayout,ZerosLayout} ) where Lay <: AbstractBasisLayout = ExpansionLayout {Lay} ()
338+ const CoefficientLayouts = Union{PaddedLayout,AbstractStridedLayout,ZerosLayout}
339+ applylayout (:: Type{typeof(*)} , :: Lay , :: CoefficientLayouts ) where Lay <: AbstractBasisLayout = ExpansionLayout {Lay} ()
340+
341+ tocoefficients (v) = tocoefficients_layout (MemoryLayout (v), v)
342+ tocoefficients_layout (:: CoefficientLayouts , v) = v
343+ tocoefficients_layout (_, v) = tocoefficients_size (size (v), v)
344+ tocoefficients_size (:: NTuple{N,Int} , v) where N = Array (v)
345+ tocoefficients_size (_, v) = v # the default is to leave it, even though we aren't technically making an ExpansionLayout
338346
339347"""
340348 basis(v)
@@ -359,7 +367,8 @@ function unweighted(lay::ExpansionLayout, a)
359367end
360368
361369LazyArrays. _mul_arguments (:: ExpansionLayout , A) = LazyArrays. _mul_arguments (ApplyLayout {typeof(*)} (), A)
362- copy (L:: Ldiv{Bas,<:ExpansionLayout} ) where Bas<: AbstractBasisLayout = copy (Ldiv {Bas,ApplyLayout{typeof(*)}} (L. A, L. B))
370+ copy (L:: Ldiv{Lay,<:ExpansionLayout} ) where Lay<: AbstractBasisLayout = copy (Ldiv {Lay,ApplyLayout{typeof(*)}} (L. A, L. B))
371+ copy (L:: Ldiv{Lay,<:ExpansionLayout} ) where Lay<: MappedBasisLayouts = copy (Ldiv {Lay,ApplyLayout{typeof(*)}} (L. A, L. B))
363372copy (L:: Mul{<:ExpansionLayout,Lay} ) where Lay = copy (Mul {ApplyLayout{typeof(*)},Lay} (L. A, L. B))
364373copy (L:: Mul{<:ExpansionLayout,Lay} ) where Lay<: AbstractLazyLayout = copy (Mul {ApplyLayout{typeof(*)},Lay} (L. A, L. B))
365374
0 commit comments