
Commit 6beb81b

Merge pull request #119 from v0lta/refactor-networkx-compression
Refactor MNIST compression example
2 parents: c01682e + e55586b

4 files changed: +211 -154 lines changed


examples/network_compression/mnist_compression.py
Lines changed: 105 additions & 73 deletions

````diff
@@ -1,12 +1,31 @@
+# /// script
+# requires-python = ">=3.14"
+# dependencies = [
+#     "matplotlib",
+#     "numpy",
+#     "ptwt",
+#     "tensorboard",
+#     "torch",
+#     "torchvision",
+#     "pystow",
+# ]
+#
+# [tool.uv.sources]
+# ptwt = { path = "../../" }
+# ///
+
 # Originally created by moritz (wolter@cs.uni-bonn.de) on 17/12/2019
 # at https://github.com/v0lta/Wavelet-network-compression/blob/master/mnist_compression.py
 # based on https://github.com/pytorch/examples/blob/master/mnist/main.py
 
 import argparse
 import collections
+from typing import Literal
 
+from pathlib import Path
 import matplotlib.pyplot as plt
 import numpy as np
+import pystow
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -16,64 +35,73 @@
 from torchvision import datasets, transforms
 from wavelet_linear import WaveletLayer
 
-from ptwt.wavelets_learnable import ProductFilter
-
+from ptwt.wavelets_learnable import ProductFilter, WaveletFilter
 
-def compute_parameter_total(net):
-    total = 0
-    for p in net.parameters():
-        if p.requires_grad:
-            print(p.shape)
-            total += np.prod(p.shape)
-    return total
+HERE = Path(__file__).parent.resolve()
+MODULE = pystow.module("torchvision", "mnist")
 
 
 class Net(nn.Module):
-    def __init__(self, compression, wavelet=None, wave_dropout=0.0):
-        super(Net, self).__init__()
+    def __init__(
+        self,
+        compression: Literal["None", "Wavelet"],
+        *,
+        wavelet: WaveletFilter | None = None,
+        wave_dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
         self.conv1 = nn.Conv2d(1, 20, 5, 1)
         self.conv2 = nn.Conv2d(20, 50, 5, 1)
-        self.dropout1 = nn.Dropout2d(0.25)
-        self.dropout2 = nn.Dropout2d(0.5)
-        self.wavelet = wavelet
-        self.do_dropout = True
+        self.max_pool_2s_k2 = torch.nn.MaxPool2d(2)
+
+        self.log_softmax = torch.nn.LogSoftmax(dim=1)
+        self.flatten = torch.nn.Flatten(start_dim=1)
+        self.relu = torch.nn.ReLU()
+
         if compression == "None":
-            self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
-            self.fc2 = torch.nn.Linear(500, 10)
+            fc1 = torch.nn.Linear(4 * 4 * 50, 500)
+            fc2 = torch.nn.Linear(500, 10)
+            self.sequence = torch.nn.Sequential(
+                self.conv1,
+                self.max_pool_2s_k2,
+                self.conv2,
+                self.max_pool_2s_k2,
+                nn.Dropout2d(0.25),
+                self.flatten,
+                fc1,
+                self.relu,
+                nn.Dropout2d(0.5),
+                fc2,
+                self.log_softmax,
+            )
         elif compression == "Wavelet":
             assert wavelet is not None, "initial wavelet must be set."
-            self.fc1 = WaveletLayer(
+            self.wavelet = wavelet
+            fc1 = WaveletLayer(
                 init_wavelet=wavelet, scales=6, depth=800, p_drop=wave_dropout
             )
-            self.fc2 = torch.nn.Linear(800, 10)
-            self.do_dropout = False
+            fc2 = torch.nn.Linear(800, 10)
+            self.sequence = torch.nn.Sequential(
+                self.conv1,
+                self.max_pool_2s_k2,
+                self.conv2,
+                self.max_pool_2s_k2,
+                self.flatten,
+                fc1,
+                self.relu,
+                fc2,
+                self.log_softmax,
+            )
         else:
-            raise ValueError("Compression type Unknown.")
-
-    def forward(self, x):
-        x = self.conv1(x)
-        # x = F.relu(x)
-        x = F.max_pool2d(x, 2)
-        x = self.conv2(x)
-        x = F.max_pool2d(x, 2)
-        if self.do_dropout:
-            x = self.dropout1(x)
-        x = torch.flatten(x, 1)
-        x = self.fc1(x)
-        x = F.relu(x)
-        if self.do_dropout:
-            x = self.dropout2(x)
-        x = self.fc2(x)
-        output = F.log_softmax(x, dim=1)
-        return output
+            raise ValueError(f"invalid compression: {compression}")
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.sequence(x)
 
     def wavelet_loss(self):
         if self.wavelet is None:
-            return torch.tensor(0.0), torch.tensor(0.0)
-        else:
-            acl, _, _ = self.fc1.wavelet.alias_cancellation_loss()
-            prl, _, _ = self.fc1.wavelet.perfect_reconstruction_loss()
-            return acl, prl
+            return torch.tensor(0.0)
+        return self.wavelet.wavelet_loss()
 
 
 def train(args, model, device, train_loader, optimizer, epoch):
@@ -82,27 +110,23 @@ def train(args, model, device, train_loader, optimizer, epoch):
         data, target = data.to(device), target.to(device)
         optimizer.zero_grad()
         output = model(data)
-        nll_loss = F.nll_loss(output, target)
+        loss = F.nll_loss(output, target)
         if args.compression == "Wavelet":
-            acl, prl = model.wavelet_loss()
-            wvl = acl + prl
-            loss = nll_loss + wvl * args.wave_loss_weight
-        else:
-            wvl = torch.tensor(0.0)
-            loss = nll_loss
+            wvl = model.wavelet_loss()
+            loss = loss + wvl * args.wave_loss_weight
         loss.backward()
         optimizer.step()
         if batch_idx % args.log_interval == 0:
-            print(
-                "Train Epoch: {} [{}/{} ({:.0f}%)], Loss: {:.6f}, wvl-Loss: {:.6f}".format(
-                    epoch,
-                    batch_idx * len(data),
-                    len(train_loader.dataset),
-                    100.0 * batch_idx / len(train_loader),
-                    nll_loss.item(),
-                    wvl.item(),
-                )
+            msg = "Train Epoch: {} [{}/{} ({:.0f}%)], Loss: {:.6f}".format(
+                epoch,
+                batch_idx * len(data),
+                len(train_loader.dataset),
+                100.0 * batch_idx / len(train_loader),
+                loss.item(),
             )
+            if args.compression == "Wavelet":
+                msg += f", wvl-loss: {wvl.item():.6f}"
+            print(msg)
 
 
 def test(args, model, device, test_loader, test_writer, epoch):
@@ -122,8 +146,8 @@ def test(args, model, device, test_loader, test_writer, epoch):
             correct += pred.eq(target.view_as(pred)).sum().item()
 
     test_loss /= len(test_loader.dataset)
-    acl, prl = model.wavelet_loss()
-    wvl_loss = acl + prl
+
+    wvl_loss = model.wavelet_loss()
 
     print(
         "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
@@ -221,16 +245,24 @@ def main():
 
     args = parser.parse_args()
     print(args)
-    use_cuda = not args.no_cuda and torch.cuda.is_available()
+
+    if args.no_cuda:
+        device = "cpu"
+    elif torch.cuda.is_available():
+        device = "cuda"
+    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
+        device = "mps"
+    else:
+        device = "cpu"
 
     torch.manual_seed(args.seed)
 
-    device = torch.device("cuda" if use_cuda else "cpu")
+    device = torch.device(device)
 
-    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
+    kwargs = {"num_workers": 1, "pin_memory": True} if device == "cuda" else {}
     train_loader = torch.utils.data.DataLoader(
         datasets.MNIST(
-            "../data",
+            MODULE.base,
            train=True,
            download=True,
            transform=transforms.Compose(
@@ -243,7 +275,7 @@ def main():
     )
     test_loader = torch.utils.data.DataLoader(
         datasets.MNIST(
-            "../data",
+            MODULE.base,
            train=False,
            transform=transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
@@ -300,20 +332,20 @@ def main():
     if args.save_model:
         torch.save(model.state_dict(), "mnist_cnn.pt")
 
-    print(compute_parameter_total(model))
+    n_params = sum(np.prod(p.shape) for p in model.parameters() if p.requires_grad)
+    print(f"the model has {n_params:,} parameters")
 
     # plt.semilogy(test_wvl_lst)
     # plt.semilogy(test_acc_lst)
     # plt.legend(['wavlet loss', 'accuracy'])
     # plt.show()
 
-    plt.plot(model.fc1.wavelet.dec_lo.detach().cpu().numpy(), "-*")
-    plt.plot(model.fc1.wavelet.dec_hi.detach().cpu().numpy(), "-*")
-    plt.plot(model.fc1.wavelet.rec_lo.detach().cpu().numpy(), "-*")
-    plt.plot(model.fc1.wavelet.rec_hi.detach().cpu().numpy(), "-*")
+    plt.plot(model.wavelet.filter_bank.dec_lo.detach().cpu().numpy(), "-*")
+    plt.plot(model.wavelet.filter_bank.dec_hi.detach().cpu().numpy(), "-*")
+    plt.plot(model.wavelet.filter_bank.rec_lo.detach().cpu().numpy(), "-*")
+    plt.plot(model.wavelet.filter_bank.rec_hi.detach().cpu().numpy(), "-*")
     plt.legend(["H_0", "H_1", "F_0", "F_1"])
-    plt.show()
-    print("done")
+    plt.savefig(HERE.joinpath("plot.svg"))
 
 
 if __name__ == "__main__":
````
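For orientation, here is a brief, hypothetical usage sketch of the refactored `Net`; it is not part of this commit. It assumes it runs from `examples/network_compression`, where `mnist_compression.py` and `wavelet_linear.py` live, that `ProductFilter` takes the four filter tensors `(dec_lo, dec_hi, rec_lo, rec_hi)` as in `ptwt.wavelets_learnable`, and that a Haar filter bank is a reasonable learnable starting point:

```python
# Hypothetical sketch, not part of this commit.
import math

import torch
import torch.nn.functional as F
from ptwt.wavelets_learnable import ProductFilter

from mnist_compression import Net

# A Haar filter bank as a learnable starting point (assumed initialization).
s = 1 / math.sqrt(2.0)
haar = ProductFilter(
    torch.tensor([s, s]),   # dec_lo
    torch.tensor([-s, s]),  # dec_hi
    torch.tensor([s, s]),   # rec_lo
    torch.tensor([s, -s]),  # rec_hi
)

model = Net("Wavelet", wavelet=haar, wave_dropout=0.1)
x = torch.randn(8, 1, 28, 28)        # one batch of MNIST-shaped inputs
log_probs = model(x)                 # (8, 10) log-probabilities
target = torch.randint(0, 10, (8,))

# Classification loss plus the combined wavelet regularizer, mirroring the
# refactored train(); the weight 1.0 stands in for args.wave_loss_weight.
loss = F.nll_loss(log_probs, target) + 1.0 * model.wavelet_loss()
loss.backward()
```

The point of the refactor is visible here: `wavelet_loss()` now returns a single scalar combining the alias-cancellation and perfect-reconstruction terms, so callers no longer unpack two tensors.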
Lines changed: 8 additions & 4 deletions

````diff
@@ -1,10 +1,14 @@
 #### Adaptive Wavelets
+
 ```mnist_compression.py``` trains a CNN on MNIST with an adaptive-wavelet
 compressed linear layer. The wavelets in the linear layer are learned using gradient descent.
 
 See https://arxiv.org/pdf/2004.09569v3.pdf for a detailed description of the method.
 
-Running this example requires the following steps:
-- clone this repository,
-- install `ptwt`,
-- and execute ```python mnist_compression.py```.
+Running this example requires the following steps, which take care of installing everything:
+
+```console
+$ git clone https://github.com/v0lta/PyTorch-Wavelet-Toolbox.git
+$ cd PyTorch-Wavelet-Toolbox/examples/network_compression
+$ uv run mnist_compression.py
+```
````
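No separate install step is needed because the `# /// script` header this commit adds to `mnist_compression.py` is PEP 723 inline script metadata: `uv run` reads it, resolves the listed dependencies into an ephemeral environment (with the `[tool.uv.sources]` entry pointing `ptwt` at the local checkout two directories up), and then executes the script. A stripped-down illustration of the mechanism, with a hypothetical file name and dependency list:

```python
# demo.py: a minimal PEP 723 script. Running `uv run demo.py` installs the
# listed dependencies into a throwaway environment, then runs the code below.
# /// script
# requires-python = ">=3.11"
# dependencies = ["torch"]
# ///
import torch

print(torch.__version__)
```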
