Skip to content
This repository was archived by the owner on May 29, 2023. It is now read-only.

Commit 721c07a

Browse files
authored
Support Dims attribute for FFT (#33)
1 parent 29f0829 commit 721c07a

File tree

12 files changed

+196
-112
lines changed

12 files changed

+196
-112
lines changed

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,4 +174,4 @@ jobs:
174174
run: |
175175
python3 -m pip install --upgrade pip
176176
python3 -m pip install twine
177-
python3 -m twine upload --repository testpypi wheel*/*.whl --skip-existing
177+
python3 -m twine upload wheel*/*.whl --skip-existing

examples/fft/export_model.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,31 +7,30 @@
77

88

99
class MyModel(nn.Module):
10-
def __init__(self):
10+
def __init__(self, inverse, centred, dims):
1111
super(MyModel, self).__init__()
12+
self.inverse = inverse
13+
self.centred = centred
14+
self.dims = dims
1215
self.fft = FFT()
1316

1417
def forward(self, x):
15-
y = self.fft.apply(x, False)
16-
y = y * 2
17-
# TODO: there is a bug with "inverse" data attribute in OpenVINO 2021.4
18-
y = self.fft.apply(y, True)
19-
return y
18+
return self.fft.apply(x, self.inverse, self.centred, self.dims)
2019

2120

22-
def export(shape):
21+
def export(shape, inverse, centered, dims):
2322
np.random.seed(324)
2423
torch.manual_seed(32)
2524

26-
model = MyModel()
25+
model = MyModel(inverse, centered, dims)
2726
inp = Variable(torch.randn(shape))
2827
model.eval()
2928

3029
with torch.no_grad():
3130
torch.onnx.export(model, inp, 'model.onnx',
3231
input_names=['input'],
3332
output_names=['output'],
34-
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
33+
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH)
3534

3635
ref = model(inp)
3736
np.save('inp', inp.detach().numpy())

examples/fft/export_model_with_roll.py

Lines changed: 0 additions & 44 deletions
This file was deleted.

examples/fft/fft.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,39 +34,41 @@ def roll(
3434
right_part = data.narrow(dim_index, data.size(dims) - shift, shift)
3535
return torch.cat([right_part, left_part], dim=dim_index)
3636

37-
def fftshift(data: torch.Tensor) -> torch.Tensor:
38-
dim = (1, 2)
39-
shift = [data.size(curr_dim) // 2 for curr_dim in dim]
40-
return roll(data, shift, dim)
37+
def fftshift(data: torch.Tensor, dims) -> torch.Tensor:
38+
shift = [data.size(curr_dim) // 2 for curr_dim in dims]
39+
return roll(data, shift, dims)
4140

42-
def ifftshift(data: torch.Tensor) -> torch.Tensor:
43-
dim = (1, 2)
44-
shift = [(data.size(curr_dim) + 1) // 2 for curr_dim in dim]
45-
return roll(data, shift, dim)
41+
def ifftshift(data: torch.Tensor, dims) -> torch.Tensor:
42+
shift = [(data.size(curr_dim) + 1) // 2 for curr_dim in dims]
43+
return roll(data, shift, dims)
4644

4745
class FFT(torch.autograd.Function):
4846
@staticmethod
49-
def symbolic(g, x, inverse, centered=False):
50-
return g.op('IFFT' if inverse else 'FFT', x,
47+
def symbolic(g, x, inverse, centered, dims):
48+
dims = torch.tensor(dims)
49+
dims = g.op("Constant", value_t=dims)
50+
51+
return g.op('IFFT' if inverse else 'FFT', x, dims,
5152
inverse_i=inverse, centered_i=centered)
5253

5354
@staticmethod
54-
def forward(self, x, inverse, centered=False):
55+
def forward(self, x, inverse, centered, dims):
5556
# https://pytorch.org/docs/stable/torch.html#torch.fft
56-
signal_ndim = 2 if len(x.shape) == 5 else 1
5757
if centered:
58-
x = ifftshift(x)
58+
x = ifftshift(x, dims)
5959

6060
if version.parse(torch.__version__) >= version.parse("1.8.0"):
6161
func = torch.fft.ifftn if inverse else torch.fft.fftn
6262
x = torch.view_as_complex(x)
63-
y = func(x, dim=list(range(1, signal_ndim + 1)), norm="ortho")
63+
y = func(x, dim=dims, norm="ortho")
6464
y = torch.view_as_real(y)
6565
else:
66+
signal_ndim = max(dims)
67+
assert dims == list(range(1, signal_ndim + 1))
6668
func = torch.ifft if inverse else torch.fft
6769
y = func(input=x, signal_ndim=signal_ndim, normalized=True)
6870

6971
if centered:
70-
y = fftshift(y)
72+
y = fftshift(y, dims)
7173

7274
return y

mo_extensions/front/onnx/fft_ext.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ class FFTFrontExtractor(FrontExtractorOp):
1111
def extract(cls, node):
1212
data = {
1313
"inverse": onnx_attr(node, "inverse", "i"),
14+
"centered": onnx_attr(node, "centered", "i"),
1415
}
1516

1617
FFT.update_node_stat(node, data)
@@ -25,6 +26,7 @@ class IFFTFrontExtractor(FrontExtractorOp):
2526
def extract(cls, node):
2627
data = {
2728
"inverse": 1,
29+
"centered": onnx_attr(node, "centered", "i"),
2830
}
2931

3032
IFFT.update_node_stat(node, data)

mo_extensions/ops/FFT.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@ def __init__(self, graph: Graph, attrs: dict):
1313
{
1414
"type": __class__.op,
1515
"op": __class__.op,
16-
"in_ports_count": 1,
16+
"in_ports_count": 2,
1717
"out_ports_count": 1,
1818
"infer": copy_shape_infer,
1919
},
2020
attrs,
2121
)
2222

2323
def supported_attrs(self):
24-
return ["inverse"]
24+
return ["inverse", "centered"]
2525

2626

2727
class IFFT(Op):
@@ -34,12 +34,12 @@ def __init__(self, graph: Graph, attrs: dict):
3434
{
3535
"type": __class__.op,
3636
"op": __class__.op,
37-
"in_ports_count": 1,
37+
"in_ports_count": 2,
3838
"out_ports_count": 1,
3939
"infer": copy_shape_infer,
4040
},
4141
attrs,
4242
)
4343

4444
def supported_attrs(self):
45-
return ["inverse"]
45+
return ["inverse", "centered"]

tests/run_tests.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -60,18 +60,21 @@ def test_unpool_reshape():
6060
export(mode='dynamic_size', shape=[4, 3, 17, 8])
6161
run_test(convert_ir=False)
6262

63-
@pytest.mark.parametrize("shape", [[5, 120, 2], [4, 240, 320, 2], [3, 5, 240, 320, 2]])
64-
def test_fft(shape):
63+
@pytest.mark.parametrize("shape", [[5, 120, 2], [4, 240, 320, 2], [3, 16, 240, 320, 2]])
64+
@pytest.mark.parametrize("inverse", [False, True])
65+
@pytest.mark.parametrize("centered", [False, True])
66+
@pytest.mark.parametrize("test_onnx", [False, True])
67+
@pytest.mark.parametrize("dims", [[1], [1, 2], [2, 3]])
68+
def test_fft(shape, inverse, centered, test_onnx, dims):
6569
from examples.fft.export_model import export
6670

67-
export(shape=shape)
68-
run_test()
69-
70-
@pytest.mark.parametrize("test_onnx", [True, False])
71-
def test_fft_roll(test_onnx):
72-
from examples.fft.export_model_with_roll import export
71+
if len(shape) == 3 and dims != [1] or \
72+
len(shape) == 4 and dims == [2, 3] or \
73+
len(shape) == 5 and dims == [1] or \
74+
centered and len(dims) != 2:
75+
pytest.skip("unsupported configuration")
7376

74-
export()
77+
export(shape, inverse, centered, dims)
7578
run_test(test_onnx=test_onnx)
7679

7780

user_ie_extensions/cpu_kernel.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class FFTImpl : public InferenceEngine::ILayerExecImpl {
3838
std::vector<InferenceEngine::Blob::Ptr> &outputs,
3939
InferenceEngine::ResponseDesc *resp) noexcept override;
4040
private:
41-
ngraph::Shape inpShape;
41+
std::vector<ngraph::Shape> inShapes;
4242
ngraph::Shape outShape;
4343
bool inverse, centered;
4444
std::string error;

user_ie_extensions/extension.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@ Extension::Extension() {
2121
ngraph::OutputVector ng_inputs {node.get_ng_inputs()};
2222
bool inverse = node.get_attribute_value<int64_t>("inverse");
2323
bool centered = node.get_attribute_value<int64_t>("centered");
24-
return {std::make_shared<FFTOp>(ng_inputs.at(0), inverse, centered)};
24+
return {std::make_shared<FFTOp>(ng_inputs.at(0), ng_inputs.at(1), inverse, centered)};
2525
});
2626
ngraph::onnx_import::register_operator(IFFTOp::type_info.name, 1, "", [](const ngraph::onnx_import::Node& node) -> ngraph::OutputVector {
2727
ngraph::OutputVector ng_inputs {node.get_ng_inputs()};
2828
bool inverse = node.get_attribute_value<int64_t>("inverse");
2929
bool centered = node.get_attribute_value<int64_t>("centered");
30-
return {std::make_shared<IFFTOp>(ng_inputs.at(0), inverse, centered)};
30+
return {std::make_shared<IFFTOp>(ng_inputs.at(0), ng_inputs.at(1), inverse, centered)};
3131
});
3232
ngraph::onnx_import::register_operator(ComplexMulOp::type_info.name, 1, "", [](const ngraph::onnx_import::Node& node) -> ngraph::OutputVector {
3333
ngraph::OutputVector ng_inputs {node.get_ng_inputs()};

0 commit comments

Comments
 (0)