@@ -34,39 +34,41 @@ def roll(
3434 right_part = data .narrow (dim_index , data .size (dims ) - shift , shift )
3535 return torch .cat ([right_part , left_part ], dim = dim_index )
3636
37- def fftshift (data : torch .Tensor ) -> torch .Tensor :
38- dim = (1 , 2 )
39- shift = [data .size (curr_dim ) // 2 for curr_dim in dim ]
40- return roll (data , shift , dim )
37+ def fftshift (data : torch .Tensor , dims ) -> torch .Tensor :
38+ shift = [data .size (curr_dim ) // 2 for curr_dim in dims ]
39+ return roll (data , shift , dims )
4140
42- def ifftshift (data : torch .Tensor ) -> torch .Tensor :
43- dim = (1 , 2 )
44- shift = [(data .size (curr_dim ) + 1 ) // 2 for curr_dim in dim ]
45- return roll (data , shift , dim )
41+ def ifftshift (data : torch .Tensor , dims ) -> torch .Tensor :
42+ shift = [(data .size (curr_dim ) + 1 ) // 2 for curr_dim in dims ]
43+ return roll (data , shift , dims )
4644
4745class FFT (torch .autograd .Function ):
4846 @staticmethod
49- def symbolic (g , x , inverse , centered = False ):
50- return g .op ('IFFT' if inverse else 'FFT' , x ,
47+ def symbolic (g , x , inverse , centered , dims ):
48+ dims = torch .tensor (dims )
49+ dims = g .op ("Constant" , value_t = dims )
50+
51+ return g .op ('IFFT' if inverse else 'FFT' , x , dims ,
5152 inverse_i = inverse , centered_i = centered )
5253
5354 @staticmethod
54- def forward (self , x , inverse , centered = False ):
55+ def forward (self , x , inverse , centered , dims ):
5556 # https://pytorch.org/docs/stable/torch.html#torch.fft
56- signal_ndim = 2 if len (x .shape ) == 5 else 1
5757 if centered :
58- x = ifftshift (x )
58+ x = ifftshift (x , dims )
5959
6060 if version .parse (torch .__version__ ) >= version .parse ("1.8.0" ):
6161 func = torch .fft .ifftn if inverse else torch .fft .fftn
6262 x = torch .view_as_complex (x )
63- y = func (x , dim = list ( range ( 1 , signal_ndim + 1 )) , norm = "ortho" )
63+ y = func (x , dim = dims , norm = "ortho" )
6464 y = torch .view_as_real (y )
6565 else :
66+ signal_ndim = max (dims )
67+ assert dims == list (range (1 , signal_ndim + 1 ))
6668 func = torch .ifft if inverse else torch .fft
6769 y = func (input = x , signal_ndim = signal_ndim , normalized = True )
6870
6971 if centered :
70- y = fftshift (y )
72+ y = fftshift (y , dims )
7173
7274 return y
0 commit comments