@@ -11,34 +11,33 @@ using LinearAlgebra
1111using Test
1212
1313"""
14- test_fft_backend(array_constructor ; test_real=true, test_inplace=true)
14+ test_fft_backend(ArrayType=Array ; test_real=true, test_inplace=true)
1515
1616Run tests to verify correctness of FFT functions using a particular
1717backend plan implementation. The backend implementation is assumed to be loaded
1818prior to calling this function.
1919
2020# Arguments
2121
22- - `array_constructor`: determines the `AbstractArray` implementation for
23- which the correctness tests are run. It is assumed to be a callable object that
24- takes in input arrays of type `Array` and return arrays of the desired type for
25- testing. For example, this can be a constructor such as `Array` or `CUDA.CuArray`.
22+ - `ArrayType`: determines the `AbstractArray` implementation for
23+ which the correctness tests are run. Arrays are constructed via
24+ `convert(ArrayType, ...)`.
2625- `test_real=true`: whether to test real-to-complex and complex-to-real FFTs.
2726- `test_inplace=true`: whether to test in-place plans.
2827"""
29- function test_fft_backend (array_constructor ; test_real= true , test_inplace= true )
28+ function test_fft_backend (ArrayType = Array ; test_real= true , test_inplace= true )
3029 @testset " fft correctness" begin
3130 # DFT along last dimension, results computed using FFTW
32- for (_x, dims, real_input, _fftw_fft) in (
33- (collect (1 : 7 ), 1 , true ,
31+ for (_x, dims, _fftw_fft) in (
32+ (collect (1 : 7 ), 1 ,
3433 [28.0 + 0.0im ,
3534 - 3.5 + 7.267824888003178im ,
3635 - 3.5 + 2.7911568610884143im ,
3736 - 3.5 + 0.7988521603655248im ,
3837 - 3.5 - 0.7988521603655248im ,
3938 - 3.5 - 2.7911568610884143im ,
4039 - 3.5 - 7.267824888003178im ]),
41- (collect (1 : 8 ), 1 , true ,
40+ (collect (1 : 8 ), 1 ,
4241 [36.0 + 0.0im ,
4342 - 4.0 + 9.65685424949238im ,
4443 - 4.0 + 4.0im ,
@@ -47,49 +46,50 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
4746 - 4.0 - 1.6568542494923806im ,
4847 - 4.0 - 4.0im ,
4948 - 4.0 - 9.65685424949238im ]),
50- (collect (reshape (1 : 8 , 2 , 4 )), 2 , true ,
49+ (collect (reshape (1 : 8 , 2 , 4 )), 2 ,
5150 [16.0 + 0.0im - 4.0 + 4.0im - 4.0 + 0.0im - 4.0 - 4.0im ;
5251 20.0 + 0.0im - 4.0 + 4.0im - 4.0 + 0.0im - 4.0 - 4.0im ]),
53- (collect (reshape (1 : 9 , 3 , 3 )), 2 , true ,
52+ (collect (reshape (1 : 9 , 3 , 3 )), 2 ,
5453 [12.0 + 0.0im - 4.5 + 2.598076211353316im - 4.5 - 2.598076211353316im ;
5554 15.0 + 0.0im - 4.5 + 2.598076211353316im - 4.5 - 2.598076211353316im ;
5655 18.0 + 0.0im - 4.5 + 2.598076211353316im - 4.5 - 2.598076211353316im ]),
57- (collect (reshape (1 : 8 , 2 , 2 , 2 )), 1 : 2 , true ,
56+ (collect (reshape (1 : 8 , 2 , 2 , 2 )), 1 : 2 ,
5857 cat ([10.0 + 0.0im - 4.0 + 0.0im ; - 2.0 + 0.0im 0.0 + 0.0im ],
5958 [26.0 + 0.0im - 4.0 + 0.0im ; - 2.0 + 0.0im 0.0 + 0.0im ],
6059 dims= 3 )),
61- (collect (1 : 7 ) + im * collect (8 : 14 ), 1 , false ,
60+ (collect (1 : 7 ) + im * collect (8 : 14 ), 1 ,
6261 [28.0 + 77.0im ,
6362 - 10.76782488800318 + 3.767824888003175im ,
6463 - 6.291156861088416 - 0.7088431389115883im ,
6564 - 4.298852160365525 - 2.7011478396344746im ,
6665 - 2.7011478396344764 - 4.298852160365524im ,
6766 - 0.7088431389115866 - 6.291156861088417im ,
6867 3.767824888003177 - 10.76782488800318im ]),
69- (collect (reshape (1 : 8 , 2 , 2 , 2 )) + im * reshape (9 : 16 , 2 , 2 , 2 ), 1 : 2 , false ,
68+ (collect (reshape (1 : 8 , 2 , 2 , 2 )) + im * reshape (9 : 16 , 2 , 2 , 2 ), 1 : 2 ,
7069 cat ([10.0 + 42.0im - 4.0 - 4.0im ; - 2.0 - 2.0im 0.0 + 0.0im ],
7170 [26.0 + 58.0im - 4.0 - 4.0im ; - 2.0 - 2.0im 0.0 + 0.0im ],
7271 dims= 3 )),
7372 )
74- x = array_constructor (_x) # dummy array that will be passed to plans
75- x_complex = complex .(float .(x)) # for testing complex FFTs
76- fftw_fft = array_constructor (_fftw_fft)
73+ x = convert (ArrayType, _x) # dummy array that will be passed to plans
74+ x_complex = convert (ArrayType, complex .(x)) # for testing complex FFTs
75+ x_complexfloat = convert (ArrayType, complex .(float .(x))) # for in-place operations
76+ fftw_fft = convert (ArrayType, _fftw_fft)
7777
7878 # FFT
7979 y = AbstractFFTs. fft (x_complex, dims)
8080 @test y ≈ fftw_fft
8181 if test_inplace
82- @test AbstractFFTs. fft! (copy (x_complex ), dims) ≈ fftw_fft
82+ @test AbstractFFTs. fft! (copy (x_complexfloat ), dims) ≈ fftw_fft
8383 end
8484 # test plan_fft and also inv and plan_inv of plan_ifft, which should all give
8585 # functionally identical plans
8686 plans_to_test = [plan_fft (x, dims), inv (plan_ifft (x, dims)),
8787 AbstractFFTs. plan_inv (plan_ifft (x, dims))]
8888 for P in plans_to_test
89- @test mul! (similar (y), P, copy (x_complex )) ≈ fftw_fft
89+ @test mul! (similar (y), P, copy (x_complexfloat )) ≈ fftw_fft
9090 end
9191 if test_inplace
92- push! (plans_to_test, plan_fft! (similar (x_complex ), dims))
92+ push! (plans_to_test, plan_fft! (similar (x_complexfloat ), dims))
9393 end
9494 for P in plans_to_test
9595 @test eltype (P) <: Complex
@@ -105,7 +105,7 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
105105 @test AbstractFFTs. bfft! (copy (y), dims) ≈ fftw_bfft
106106 end
107107 P = plan_bfft (similar (y), dims)
108- @test mul! (similar (x_complex ), P, copy (y)) ≈ fftw_bfft
108+ @test mul! (similar (x_complexfloat ), P, copy (y)) ≈ fftw_bfft
109109 plans_to_test = if test_inplace
110110 [P, plan_bfft! (similar (y), dims)]
111111 else
@@ -127,10 +127,10 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
127127 plans_to_test = [plan_ifft (x, dims), inv (plan_fft (x, dims)),
128128 AbstractFFTs. plan_inv (plan_fft (x, dims))]
129129 for P in plans_to_test
130- @test mul! (similar (x_complex ), P, copy (y)) ≈ fftw_ifft
130+ @test mul! (similar (x_complexfloat ), P, copy (y)) ≈ fftw_ifft
131131 end
132132 if test_inplace
133- push! (plans_to_test, plan_ifft! (similar (x_complex ), dims))
133+ push! (plans_to_test, plan_ifft! (similar (x_complexfloat ), dims))
134134 end
135135 for P in plans_to_test
136136 @test eltype (P) <: Complex
@@ -139,7 +139,7 @@ function test_fft_backend(array_constructor; test_real=true, test_inplace=true)
139139 @test fftdims (P) == dims
140140 end
141141
142- if test_real && real_input
142+ if test_real && (x isa Real)
143143 x_real = float .(x) # for testing real FFTs
144144 # RFFT
145145 fftw_rfft = selectdim (fftw_fft, first (dims), 1 : (size (fftw_fft, first (dims)) ÷ 2 + 1 ))
0 commit comments