Skip to content
This repository was archived by the owner on May 29, 2023. It is now read-only.

Commit 4fd2091

Browse files
authored
Fix FFTShift for odd dimensions (#35)
1 parent 721c07a commit 4fd2091

File tree

3 files changed

+60
-11
lines changed

3 files changed

+60
-11
lines changed

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ on:
1212

1313
env:
1414
OPENVINO_VERSION: 2021.4.2
15-
VERSION: 2021.4.2.4
15+
VERSION: 2021.4.2.5
1616
DIST_VERSION: 2021.4.752
1717
DIST_WIN: https://registrationcenter-download.intel.com/akdlm/irc_nas/18320/w_openvino_toolkit_p_2021.4.752.exe
1818
DIST_MAC: https://registrationcenter-download.intel.com/akdlm/irc_nas/18317/m_openvino_toolkit_p_2021.4.752.dmg

tests/run_tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def test_unpool_reshape():
6060
export(mode='dynamic_size', shape=[4, 3, 17, 8])
6161
run_test(convert_ir=False)
6262

63-
@pytest.mark.parametrize("shape", [[5, 120, 2], [4, 240, 320, 2], [3, 16, 240, 320, 2]])
63+
@pytest.mark.parametrize("shape", [[5, 120, 2], [4, 240, 320, 2], [3, 16, 240, 320, 2], [4, 5, 16, 31, 2]])
6464
@pytest.mark.parametrize("inverse", [False, True])
6565
@pytest.mark.parametrize("centered", [False, True])
6666
@pytest.mark.parametrize("test_onnx", [False, True])

user_ie_extensions/fft_impl.cpp

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ InferenceEngine::StatusCode FFTImpl::init(InferenceEngine::LayerConfig &config,
119119
}
120120
//! [cpu_implementation:init]
121121

122-
static void fftshift(CvMat* src) {
122+
static void fftshift(CvMat* src, bool inverse) {
123123
static auto cvCloneMat = reinterpret_cast<cvCloneMatF*>(so->get_symbol("cvCloneMat"));
124124
static auto cvCopy = reinterpret_cast<cvCopyF*>(so->get_symbol("cvCopy"));
125125
static auto cvInitMatHeader = reinterpret_cast<cvInitMatHeaderF*>(so->get_symbol("cvInitMatHeader"));
@@ -141,6 +141,55 @@ static void fftshift(CvMat* src) {
141141
int h2 = height / 2;
142142
int w2 = width / 2;
143143

144+
if (height % 2 || width % 2) {
145+
// Swap rows.
146+
CvMat* srcTop = new CvMat();
147+
CvMat* srcBot = new CvMat();
148+
CvMat* dstTop = new CvMat();
149+
CvMat* dstBot = new CvMat();
150+
int topH = inverse ? h2 : (h2 + height % 2);
151+
int botH = height - topH;
152+
cvInitMatHeader(srcTop, topH, width, CV_32FC2, data, step);
153+
cvInitMatHeader(srcBot, botH, width, CV_32FC2, data + topH * width * 2, step);
154+
cvInitMatHeader(dstTop, topH, width, CV_32FC2, data + botH * width * 2, step);
155+
cvInitMatHeader(dstBot, botH, width, CV_32FC2, data, step);
156+
157+
CvMat* tmp = cvCloneMat(srcTop);
158+
cvCopy(srcBot, dstBot, 0);
159+
cvCopy(tmp, dstTop, 0);
160+
161+
cvReleaseMat(&tmp);
162+
delete srcTop;
163+
delete srcBot;
164+
delete dstTop;
165+
delete dstBot;
166+
167+
// Swap columns.
168+
CvMat* srcL = new CvMat();
169+
CvMat* srcR = new CvMat();
170+
CvMat* dstL = new CvMat();
171+
CvMat* dstR = new CvMat();
172+
int leftW = inverse ? w2 : (w2 + width % 2);
173+
int rightW = width - leftW;
174+
175+
cvInitMatHeader(srcL, height, leftW, CV_32FC2, data, step);
176+
cvInitMatHeader(srcR, height, rightW, CV_32FC2, data + leftW * 2, step);
177+
cvInitMatHeader(dstL, height, leftW, CV_32FC2, data + rightW * 2, step);
178+
cvInitMatHeader(dstR, height, rightW, CV_32FC2, data, step);
179+
180+
tmp = cvCloneMat(srcL);
181+
cvCopy(srcR, dstR, 0);
182+
cvCopy(tmp, dstL, 0);
183+
184+
cvReleaseMat(&tmp);
185+
delete srcL;
186+
delete srcR;
187+
delete dstL;
188+
delete dstR;
189+
190+
return;
191+
}
192+
144193
CvMat* tl = new CvMat();
145194
CvMat* tr = new CvMat();
146195
CvMat* bl = new CvMat();
@@ -222,7 +271,7 @@ InferenceEngine::StatusCode FFTImpl::execute(std::vector<InferenceEngine::Blob::
222271
cvCreateData(out);
223272

224273
if (centered)
225-
fftshift(inp);
274+
fftshift(inp, true);
226275

227276
if (inverse)
228277
cvDFT(inp, out, CV_DXT_INVERSE, 0);
@@ -231,7 +280,7 @@ InferenceEngine::StatusCode FFTImpl::execute(std::vector<InferenceEngine::Blob::
231280
cvScale(out, out, 1.0 / sqrtf(channels * rows), 0);
232281

233282
if (centered)
234-
fftshift(out);
283+
fftshift(out, false);
235284

236285
CvMat out_col_header, *out_col;
237286
out_col = cvReshape(out, &out_col_header, 2, channels * rows);
@@ -259,7 +308,7 @@ InferenceEngine::StatusCode FFTImpl::execute(std::vector<InferenceEngine::Blob::
259308
cvSetData(out, reinterpret_cast<void*>(outData + d * planeSize), cols * 2 * sizeof(float));
260309

261310
if (centered)
262-
fftshift(inp);
311+
fftshift(inp, true);
263312

264313
if (inverse)
265314
cvDFT(inp, out, CV_DXT_INVERSE, 0);
@@ -268,7 +317,7 @@ InferenceEngine::StatusCode FFTImpl::execute(std::vector<InferenceEngine::Blob::
268317
cvScale(out, out, 1.0 / sqrtf(cols * rows), 0);
269318

270319
if (centered)
271-
fftshift(out);
320+
fftshift(out, false);
272321

273322
cvReleaseMat(&inp);
274323
cvReleaseMat(&out);
@@ -284,7 +333,7 @@ InferenceEngine::StatusCode FFTImpl::execute(std::vector<InferenceEngine::Blob::
284333
cvSetData(out, reinterpret_cast<void*>(outData + d * planeSize), cols * 2 * sizeof(float));
285334

286335
if (centered)
287-
fftshift(inp);
336+
fftshift(inp, true);
288337

289338
if (inverse)
290339
cvDFT(inp, out, CV_DXT_INVERSE, 0);
@@ -293,7 +342,7 @@ InferenceEngine::StatusCode FFTImpl::execute(std::vector<InferenceEngine::Blob::
293342
cvScale(out, out, 1.0 / sqrtf(cols * rows), 0);
294343

295344
if (centered)
296-
fftshift(out);
345+
fftshift(out, false);
297346

298347
cvReleaseMat(&inp);
299348
cvReleaseMat(&out);
@@ -312,7 +361,7 @@ InferenceEngine::StatusCode FFTImpl::execute(std::vector<InferenceEngine::Blob::
312361
cvSetData(out, reinterpret_cast<void*>(outData + (b * planeSize * cols + col) * 2), cols * 2 * sizeof(float));
313362

314363
if (centered)
315-
fftshift(inp);
364+
fftshift(inp, true);
316365

317366
if (inverse)
318367
cvDFT(inp, out, CV_DXT_INVERSE, 0);
@@ -321,7 +370,7 @@ InferenceEngine::StatusCode FFTImpl::execute(std::vector<InferenceEngine::Blob::
321370
cvScale(out, out, 1.0 / sqrtf(rows), 0);
322371

323372
if (centered)
324-
fftshift(out);
373+
fftshift(out, false);
325374

326375
cvReleaseMat(&inp);
327376
cvReleaseMat(&out);

0 commit comments

Comments
 (0)