1+ # /// script
2+ # requires-python = ">=3.14"
3+ # dependencies = [
4+ # "matplotlib",
5+ # "numpy",
6+ # "ptwt",
7+ # "tensorboard",
8+ # "torch",
9+ # "torchvision",
10+ # "pystow",
11+ # ]
12+ #
13+ # [tool.uv.sources]
14+ # ptwt = { path = "../../" }
15+ # ///
16+
117# Originally created by moritz (wolter@cs.uni-bonn.de) on 17/12/2019
218# at https://github.com/v0lta/Wavelet-network-compression/blob/master/mnist_compression.py
319# based on https://github.com/pytorch/examples/blob/master/mnist/main.py
420
521import argparse
622import collections
23+ from typing import Literal
724
25+ from pathlib import Path
826import matplotlib .pyplot as plt
927import numpy as np
28+ import pystow
1029import torch
1130import torch .nn as nn
1231import torch .nn .functional as F
1635from torchvision import datasets , transforms
1736from wavelet_linear import WaveletLayer
1837
19- from ptwt .wavelets_learnable import ProductFilter
20-
38+ from ptwt .wavelets_learnable import ProductFilter , WaveletFilter
2139
22- def compute_parameter_total (net ):
23- total = 0
24- for p in net .parameters ():
25- if p .requires_grad :
26- print (p .shape )
27- total += np .prod (p .shape )
28- return total
40+ HERE = Path (__file__ ).parent .resolve ()
41+ MODULE = pystow .module ("torchvision" , "mnist" )
2942
3043
3144class Net (nn .Module ):
32- def __init__ (self , compression , wavelet = None , wave_dropout = 0.0 ):
33- super (Net , self ).__init__ ()
45+ def __init__ (
46+ self ,
47+ compression : Literal ["None" , "Wavelet" ],
48+ * ,
49+ wavelet : WaveletFilter | None = None ,
50+ wave_dropout : float = 0.0 ,
51+ ) -> None :
52+ super ().__init__ ()
3453 self .conv1 = nn .Conv2d (1 , 20 , 5 , 1 )
3554 self .conv2 = nn .Conv2d (20 , 50 , 5 , 1 )
36- self .dropout1 = nn .Dropout2d (0.25 )
37- self .dropout2 = nn .Dropout2d (0.5 )
38- self .wavelet = wavelet
39- self .do_dropout = True
55+ self .max_pool_2s_k2 = torch .nn .MaxPool2d (2 )
56+
57+ self .log_softmax = torch .nn .LogSoftmax (dim = 1 )
58+ self .flatten = torch .nn .Flatten (start_dim = 1 )
59+ self .relu = torch .nn .ReLU ()
60+
4061 if compression == "None" :
41- self .fc1 = torch .nn .Linear (4 * 4 * 50 , 500 )
42- self .fc2 = torch .nn .Linear (500 , 10 )
62+ fc1 = torch .nn .Linear (4 * 4 * 50 , 500 )
63+ fc2 = torch .nn .Linear (500 , 10 )
64+ self .sequence = torch .nn .Sequential (
65+ self .conv1 ,
66+ self .max_pool_2s_k2 ,
67+ self .conv2 ,
68+ self .max_pool_2s_k2 ,
69+ nn .Dropout2d (0.25 ),
70+ self .flatten ,
71+ fc1 ,
72+ self .relu ,
73+ nn .Dropout2d (0.5 ),
74+ fc2 ,
75+ self .log_softmax ,
76+ )
4377 elif compression == "Wavelet" :
4478 assert wavelet is not None , "initial wavelet must be set."
45- self .fc1 = WaveletLayer (
79+ self .wavelet = wavelet
80+ fc1 = WaveletLayer (
4681 init_wavelet = wavelet , scales = 6 , depth = 800 , p_drop = wave_dropout
4782 )
48- self .fc2 = torch .nn .Linear (800 , 10 )
49- self .do_dropout = False
83+ fc2 = torch .nn .Linear (800 , 10 )
84+ self .sequence = torch .nn .Sequential (
85+ self .conv1 ,
86+ self .max_pool_2s_k2 ,
87+ self .conv2 ,
88+ self .max_pool_2s_k2 ,
89+ self .flatten ,
90+ fc1 ,
91+ self .relu ,
92+ fc2 ,
93+ self .log_softmax ,
94+ )
5095 else :
51- raise ValueError ("Compression type Unknown." )
52-
53- def forward (self , x ):
54- x = self .conv1 (x )
55- # x = F.relu(x)
56- x = F .max_pool2d (x , 2 )
57- x = self .conv2 (x )
58- x = F .max_pool2d (x , 2 )
59- if self .do_dropout :
60- x = self .dropout1 (x )
61- x = torch .flatten (x , 1 )
62- x = self .fc1 (x )
63- x = F .relu (x )
64- if self .do_dropout :
65- x = self .dropout2 (x )
66- x = self .fc2 (x )
67- output = F .log_softmax (x , dim = 1 )
68- return output
96+ raise ValueError (f"invalid compression: { compression } " )
97+
98+ def forward (self , x : torch .Tensor ) -> torch .Tensor :
99+ return self .sequence (x )
69100
70101 def wavelet_loss (self ):
71102 if self .wavelet is None :
72- return torch .tensor (0.0 ), torch .tensor (0.0 )
73- else :
74- acl , _ , _ = self .fc1 .wavelet .alias_cancellation_loss ()
75- prl , _ , _ = self .fc1 .wavelet .perfect_reconstruction_loss ()
76- return acl , prl
103+ return torch .tensor (0.0 )
104+ return self .wavelet .wavelet_loss ()
77105
78106
79107def train (args , model , device , train_loader , optimizer , epoch ):
@@ -82,27 +110,23 @@ def train(args, model, device, train_loader, optimizer, epoch):
82110 data , target = data .to (device ), target .to (device )
83111 optimizer .zero_grad ()
84112 output = model (data )
85- nll_loss = F .nll_loss (output , target )
113+ loss = F .nll_loss (output , target )
86114 if args .compression == "Wavelet" :
87- acl , prl = model .wavelet_loss ()
88- wvl = acl + prl
89- loss = nll_loss + wvl * args .wave_loss_weight
90- else :
91- wvl = torch .tensor (0.0 )
92- loss = nll_loss
115+ wvl = model .wavelet_loss ()
116+ loss = loss + wvl * args .wave_loss_weight
93117 loss .backward ()
94118 optimizer .step ()
95119 if batch_idx % args .log_interval == 0 :
96- print (
97- "Train Epoch: {} [{}/{} ({:.0f}%)], Loss: {:.6f}, wvl-Loss: {:.6f}" .format (
98- epoch ,
99- batch_idx * len (data ),
100- len (train_loader .dataset ),
101- 100.0 * batch_idx / len (train_loader ),
102- nll_loss .item (),
103- wvl .item (),
104- )
120+ msg = "Train Epoch: {} [{}/{} ({:.0f}%)], Loss: {:.6f}" .format (
121+ epoch ,
122+ batch_idx * len (data ),
123+ len (train_loader .dataset ),
124+ 100.0 * batch_idx / len (train_loader ),
125+ loss .item (),
105126 )
127+ if args .compression == "Wavelet" :
128+ msg += f", wvl-loss: { wvl .item ():.6f} "
129+ print (msg )
106130
107131
108132def test (args , model , device , test_loader , test_writer , epoch ):
@@ -122,8 +146,8 @@ def test(args, model, device, test_loader, test_writer, epoch):
122146 correct += pred .eq (target .view_as (pred )).sum ().item ()
123147
124148 test_loss /= len (test_loader .dataset )
125- acl , prl = model . wavelet_loss ()
126- wvl_loss = acl + prl
149+
150+ wvl_loss = model . wavelet_loss ()
127151
128152 print (
129153 "\n Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n " .format (
@@ -221,16 +245,24 @@ def main():
221245
222246 args = parser .parse_args ()
223247 print (args )
224- use_cuda = not args .no_cuda and torch .cuda .is_available ()
248+
249+ if args .no_cuda :
250+ device = "cpu"
251+ elif torch .cuda .is_available ():
252+ device = "cuda"
253+ elif torch .backends .mps .is_available () and torch .backends .mps .is_built ():
254+ device = "mps"
255+ else :
256+ device = "cpu"
225257
226258 torch .manual_seed (args .seed )
227259
228- device = torch .device ("cuda" if use_cuda else "cpu" )
260+ device = torch .device (device )
229261
230- kwargs = {"num_workers" : 1 , "pin_memory" : True } if use_cuda else {}
262+ kwargs = {"num_workers" : 1 , "pin_memory" : True } if device == "cuda" else {}
231263 train_loader = torch .utils .data .DataLoader (
232264 datasets .MNIST (
233- "../data" ,
265+ MODULE . base ,
234266 train = True ,
235267 download = True ,
236268 transform = transforms .Compose (
@@ -243,7 +275,7 @@ def main():
243275 )
244276 test_loader = torch .utils .data .DataLoader (
245277 datasets .MNIST (
246- "../data" ,
278+ MODULE . base ,
247279 train = False ,
248280 transform = transforms .Compose (
249281 [transforms .ToTensor (), transforms .Normalize ((0.1307 ,), (0.3081 ,))]
@@ -300,20 +332,20 @@ def main():
300332 if args .save_model :
301333 torch .save (model .state_dict (), "mnist_cnn.pt" )
302334
303- print (compute_parameter_total (model ))
335+ n_params = sum (np .prod (p .shape ) for p in model .parameters () if p .requires_grad )
336+ print (f"the model has { n_params :,} parameters" )
304337
305338 # plt.semilogy(test_wvl_lst)
306339 # plt.semilogy(test_acc_lst)
307340 # plt.legend(['wavlet loss', 'accuracy'])
308341 # plt.show()
309342
310- plt .plot (model .fc1 . wavelet .dec_lo .detach ().cpu ().numpy (), "-*" )
311- plt .plot (model .fc1 . wavelet .dec_hi .detach ().cpu ().numpy (), "-*" )
312- plt .plot (model .fc1 . wavelet .rec_lo .detach ().cpu ().numpy (), "-*" )
313- plt .plot (model .fc1 . wavelet .rec_hi .detach ().cpu ().numpy (), "-*" )
343+ plt .plot (model .wavelet . filter_bank .dec_lo .detach ().cpu ().numpy (), "-*" )
344+ plt .plot (model .wavelet . filter_bank .dec_hi .detach ().cpu ().numpy (), "-*" )
345+ plt .plot (model .wavelet . filter_bank .rec_lo .detach ().cpu ().numpy (), "-*" )
346+ plt .plot (model .wavelet . filter_bank .rec_hi .detach ().cpu ().numpy (), "-*" )
314347 plt .legend (["H_0" , "H_1" , "F_0" , "F_1" ])
315- plt .show ()
316- print ("done" )
348+ plt .savefig (HERE .joinpath ("plot.svg" ))
317349
318350
319351if __name__ == "__main__" :
0 commit comments