Skip to content
This repository was archived by the owner on May 29, 2023. It is now read-only.

Commit 29f0829

Browse files
authored
SparseConv from Open3D (#29)
1 parent 18edc45 commit 29f0829

18 files changed

+1066
-125
lines changed

.github/workflows/main.yml

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ on:
1212

1313
env:
1414
OPENVINO_VERSION: 2021.4.2
15-
VERSION: 2021.4.2.3
15+
VERSION: 2021.4.2.4
1616
DIST_VERSION: 2021.4.752
1717
DIST_WIN: https://registrationcenter-download.intel.com/akdlm/irc_nas/18320/w_openvino_toolkit_p_2021.4.752.exe
1818
DIST_MAC: https://registrationcenter-download.intel.com/akdlm/irc_nas/18317/m_openvino_toolkit_p_2021.4.752.dmg
@@ -134,9 +134,6 @@ jobs:
134134
test_lnx:
135135
needs: build_lnx
136136
runs-on: ubuntu-18.04
137-
strategy:
138-
matrix:
139-
torch-version: [1.7.0, 1.10.0]
140137

141138
steps:
142139
- uses: actions/checkout@v2
@@ -149,7 +146,7 @@ jobs:
149146
run: |
150147
sudo apt-get install -y python3-setuptools libopencv-dev
151148
python3 -m pip install --upgrade pip
152-
python3 -m pip install torch==${{ matrix.torch-version }} torchvision
149+
python3 -m pip install -r tests/requirements.txt
153150
python3 -m pip install -U protobuf
154151
python3 -m pip install openvino-dev[onnx]==${{env.OPENVINO_VERSION}}
155152
@@ -161,7 +158,7 @@ jobs:
161158
162159
- name: Test
163160
run: |
164-
python3 -m unittest run_tests.py
161+
python3 -m pytest tests/run_tests.py
165162
166163
publish:
167164
if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
@@ -177,4 +174,4 @@ jobs:
177174
run: |
178175
python3 -m pip install --upgrade pip
179176
python3 -m pip install twine
180-
python3 -m twine upload wheel*/*.whl --skip-existing
177+
python3 -m twine upload --repository testpypi wheel*/*.whl --skip-existing

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ Repository with guides to enable some layers from PyTorch in Intel OpenVINO:
66
* [torch.fft](examples/fft)
77
* [nn.functional.grid_sample](https://github.com/dkurt/openvino_pytorch_layers/tree/master/examples/grid_sample)
88
* [torchvision.ops.DeformConv2d](examples/deformable_conv)
9+
* [SparseConv](examples/sparse_conv) from [Open3D](https://github.com/isl-org/Open3D)
910

1011

1112
## OpenVINO Model Optimizer extension
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
5+
class CalculateGrid(torch.autograd.Function):
6+
@staticmethod
7+
def symbolic(g, in_positions):
8+
return g.op("org.open3d::calculate_grid", in_positions)
9+
10+
@staticmethod
11+
def forward(self, in_positions):
12+
filter = torch.Tensor([[-1, -1, -1], [-1, -1, 0], [-1, 0, -1], [-1, 0, 0],
13+
[0, -1, -1], [0, -1, 0], [0, 0, -1],
14+
[0, 0, 0]]).to(in_positions.device)
15+
16+
out_pos = in_positions.long().repeat(1, filter.shape[0]).reshape(-1, 3)
17+
filter = filter.repeat(in_positions.shape[0], 1)
18+
19+
out_pos = out_pos + filter
20+
out_pos = out_pos[out_pos.min(1).values >= 0]
21+
out_pos = out_pos[(~((out_pos.long() % 2).bool()).any(1))]
22+
out_pos = torch.unique(out_pos, dim=0)
23+
24+
return out_pos + 0.5
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import numpy as np
2+
import argparse
3+
import torch
4+
import torch.nn as nn
5+
from torch.autograd import Variable
6+
from .calculate_grid import CalculateGrid
7+
8+
9+
class MyModel(nn.Module):
10+
def __init__(self):
11+
super(MyModel, self).__init__()
12+
self.calculate_grid = CalculateGrid()
13+
14+
def forward(self, x):
15+
return self.calculate_grid.apply(x)
16+
17+
18+
def export(num_points, max_grid_extent):
19+
# Generate a list of unique positions and add a mantissa
20+
np.random.seed(32)
21+
torch.manual_seed(11)
22+
23+
inp_pos = np.random.randint(0, max_grid_extent, [num_points, 3])
24+
inp_pos = torch.tensor(inp_pos) + torch.rand(inp_pos.shape, dtype=torch.float32) # [0, 1)
25+
26+
model = MyModel()
27+
with torch.no_grad():
28+
torch.onnx.export(model, (inp_pos), 'model.onnx',
29+
input_names=['input'],
30+
output_names=['output'],
31+
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
32+
33+
ref = model(inp_pos).detach().numpy()
34+
35+
# Pad values with espetial end line (-1, 0, 0) and zeros
36+
ref = np.concatenate((ref, [[-1, 0, 0]]))
37+
ref = np.pad(ref, ((0, inp_pos.shape[0] - ref.shape[0]), (0, 0)))
38+
39+
np.save('inp', inp_pos.detach().numpy())
40+
np.save('ref', ref)
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import numpy as np
2+
import argparse
3+
import torch
4+
import torch.nn as nn
5+
from torch.autograd import Variable
6+
from .sparse_conv import SparseConvONNX, SparseConvTransposeONNX
7+
8+
9+
def export(num_inp_points, num_out_points, max_grid_extent, in_channels,
10+
filters, kernel_size, normalize, transpose):
11+
np.random.seed(324)
12+
torch.manual_seed(32)
13+
14+
if transpose:
15+
sparse_conv = SparseConvTransposeONNX(in_channels=in_channels,
16+
filters=filters,
17+
kernel_size=kernel_size,
18+
use_bias=False,
19+
normalize=False)
20+
else:
21+
sparse_conv = SparseConvONNX(in_channels=in_channels,
22+
filters=filters,
23+
kernel_size=kernel_size,
24+
use_bias=False,
25+
normalize=False)
26+
27+
# Generate a list of unique positions and add a mantissa
28+
def gen_pos(num_points):
29+
inp_pos = np.random.randint(0, max_grid_extent, [num_points, 3])
30+
inp_pos = np.unique(inp_pos, axis=0).astype(np.float32)
31+
inp_pos = torch.tensor(inp_pos) + torch.rand(inp_pos.shape, dtype=torch.float32) # [0, 1)
32+
return inp_pos
33+
34+
inp_pos = gen_pos(num_inp_points)
35+
out_pos = gen_pos(num_out_points) if num_out_points else inp_pos
36+
37+
features = torch.randn([inp_pos.shape[0], in_channels])
38+
39+
voxel_size = torch.tensor(1.0)
40+
sparse_conv.eval()
41+
42+
new_kernel = torch.randn(sparse_conv.state_dict()["kernel"].shape)
43+
sparse_conv.load_state_dict({"kernel": new_kernel,
44+
"offset": sparse_conv.state_dict()["offset"]})
45+
46+
with torch.no_grad():
47+
torch.onnx.export(sparse_conv, (features, inp_pos, out_pos, voxel_size), 'model.onnx',
48+
input_names=['input', 'input1', 'input2', 'voxel_size'],
49+
output_names=['output'],
50+
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
51+
52+
ref = sparse_conv(features, inp_pos, out_pos, voxel_size)
53+
np.save('inp', features.detach().numpy())
54+
np.save('inp1', inp_pos.detach().numpy())
55+
np.save('inp2', out_pos.detach().numpy())
56+
np.save('ref', ref.detach().numpy())
57+
58+
59+
if __name__ == "__main__":
60+
parser = argparse.ArgumentParser(description='Generate ONNX model and test data')
61+
parser.add_argument('--num_points', type=int)
62+
parser.add_argument('--max_grid_extent', type=int)
63+
parser.add_argument('--in_channels', type=int)
64+
parser.add_argument('--filters', type=int)
65+
parser.add_argument('--kernel_size', type=int)
66+
args = parser.parse_args()
67+
68+
export(args.num_points, args.max_grid_extent,
69+
args.in_channels, args.filters, args.kernel_size)
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
from open3d.ml.torch.layers import SparseConv, SparseConvTranspose
5+
6+
class SparseConvFunc(torch.autograd.Function):
7+
@staticmethod
8+
def symbolic(g, cls, feat, in_pos, out_pos, voxel_size):
9+
kernel = cls.state_dict()["kernel"]
10+
offset = cls.state_dict()["offset"]
11+
kernel = g.op("Constant", value_t=kernel)
12+
offset = g.op("Constant", value_t=offset)
13+
return g.op("org.open3d::SparseConv", feat, in_pos, out_pos, kernel, offset)
14+
15+
@staticmethod
16+
def forward(self, cls, feat, in_pos, out_pos, voxel_size):
17+
return cls.origin_forward(feat, in_pos, out_pos, voxel_size)
18+
19+
20+
class SparseConvONNX(SparseConv):
21+
"""
22+
This is a support class which helps export network with SparseConv in ONNX format.
23+
"""
24+
def __init__(self, *args, **kwargs):
25+
super().__init__(*args, **kwargs)
26+
self.origin_forward = super().forward
27+
28+
def forward(self, feat, in_pos, out_pos, voxel_size):
29+
return SparseConvFunc.apply(self, feat, in_pos, out_pos, voxel_size)
30+
31+
32+
class SparseConvTransposeFunc(torch.autograd.Function):
33+
@staticmethod
34+
def symbolic(g, cls, feat, in_pos, out_pos, voxel_size):
35+
kernel = cls.state_dict()["kernel"]
36+
offset = cls.state_dict()["offset"]
37+
kernel = g.op("Constant", value_t=kernel)
38+
offset = g.op("Constant", value_t=offset)
39+
return g.op("org.open3d::SparseConvTranspose", feat, in_pos, out_pos, kernel, offset)
40+
41+
@staticmethod
42+
def forward(self, cls, feat, in_pos, out_pos, voxel_size):
43+
return cls.origin_forward(feat, in_pos, out_pos, voxel_size)
44+
45+
46+
class SparseConvTransposeONNX(SparseConvTranspose):
47+
"""
48+
This is a support class which helps export network with SparseConvTranspose in ONNX format.
49+
"""
50+
def __init__(self, *args, **kwargs):
51+
super().__init__(*args, **kwargs)
52+
self.origin_forward = super().forward
53+
54+
def forward(self, feat, in_pos, out_pos, voxel_size):
55+
return SparseConvTransposeFunc.apply(self, feat, in_pos, out_pos, voxel_size)

run_tests.py

Lines changed: 0 additions & 118 deletions
This file was deleted.

tests/requirements.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
torch==1.8.1
2+
torchvision==0.9.1
3+
open3d==0.14.1
4+
tensorboard
5+
pytest

0 commit comments

Comments
 (0)