|
| 1 | +\begin{Verbatim}[commandchars=\\\{\},codes={\catcode`\$=3\catcode`\^=7\catcode`\_=8\relax}] |
| 2 | +\PYG{k+kn}{import} \PYG{n+nn}{numpy} \PYG{k}{as} \PYG{n+nn}{np} |
| 3 | +\PYG{k+kn}{from} \PYG{n+nn}{sklearn.datasets} \PYG{k+kn}{import} \PYG{n}{fetch\PYGZus{}openml} |
| 4 | +\PYG{k+kn}{from} \PYG{n+nn}{sklearn.model\PYGZus{}selection} \PYG{k+kn}{import} \PYG{n}{train\PYGZus{}test\PYGZus{}split} |
| 5 | +\PYG{k+kn}{import} \PYG{n+nn}{matplotlib.pyplot} \PYG{k}{as} \PYG{n+nn}{plt} |
| 6 | + |
| 7 | +\PYG{c+c1}{\PYGZsh{} Load and preprocess MNIST} |
| 8 | +\PYG{k}{def} \PYG{n+nf}{load\PYGZus{}binarized\PYGZus{}mnist}\PYG{p}{():} |
| 9 | + \PYG{n+nb}{print}\PYG{p}{(}\PYG{l+s+s2}{\PYGZdq{}Downloading MNIST...\PYGZdq{}}\PYG{p}{)} |
| 10 | + \PYG{n}{mnist} \PYG{o}{=} \PYG{n}{fetch\PYGZus{}openml}\PYG{p}{(}\PYG{l+s+s1}{\PYGZsq{}mnist\PYGZus{}784\PYGZsq{}}\PYG{p}{,} \PYG{n}{version}\PYG{o}{=}\PYG{l+m+mi}{1}\PYG{p}{)} |
| 11 | + \PYG{n}{X} \PYG{o}{=} \PYG{n}{mnist}\PYG{o}{.}\PYG{n}{data}\PYG{o}{.}\PYG{n}{astype}\PYG{p}{(}\PYG{n}{np}\PYG{o}{.}\PYG{n}{float32}\PYG{p}{)} \PYG{o}{/} \PYG{l+m+mf}{255.0} |
| 12 | + \PYG{n}{X} \PYG{o}{=} \PYG{p}{(}\PYG{n}{X} \PYG{o}{\PYGZgt{}} \PYG{l+m+mf}{0.5}\PYG{p}{)}\PYG{o}{.}\PYG{n}{astype}\PYG{p}{(}\PYG{n}{np}\PYG{o}{.}\PYG{n}{float32}\PYG{p}{)} \PYG{c+c1}{\PYGZsh{} Binarize} |
| 13 | + \PYG{k}{return} \PYG{n}{X} |
| 14 | + |
| 15 | +\PYG{k}{class} \PYG{n+nc}{RBM}\PYG{p}{:} |
| 16 | + \PYG{k}{def} \PYG{n+nf+fm}{\PYGZus{}\PYGZus{}init\PYGZus{}\PYGZus{}}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{p}{,} \PYG{n}{n\PYGZus{}visible}\PYG{p}{,} \PYG{n}{n\PYGZus{}hidden}\PYG{p}{,} \PYG{n}{learning\PYGZus{}rate}\PYG{o}{=}\PYG{l+m+mf}{0.1}\PYG{p}{):} |
| 17 | + \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{n\PYGZus{}visible} \PYG{o}{=} \PYG{n}{n\PYGZus{}visible} |
| 18 | + \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{n\PYGZus{}hidden} \PYG{o}{=} \PYG{n}{n\PYGZus{}hidden} |
| 19 | + \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{learning\PYGZus{}rate} \PYG{o}{=} \PYG{n}{learning\PYGZus{}rate} |
| 20 | + |
| 21 | + \PYG{c+c1}{\PYGZsh{} Initialize weights and biases} |
| 22 | + \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{W} \PYG{o}{=} \PYG{n}{np}\PYG{o}{.}\PYG{n}{random}\PYG{o}{.}\PYG{n}{normal}\PYG{p}{(}\PYG{l+m+mi}{0}\PYG{p}{,} \PYG{l+m+mf}{0.01}\PYG{p}{,} \PYG{n}{size}\PYG{o}{=}\PYG{p}{(}\PYG{n}{n\PYGZus{}visible}\PYG{p}{,} \PYG{n}{n\PYGZus{}hidden}\PYG{p}{))} |
| 23 | + \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{v\PYGZus{}bias} \PYG{o}{=} \PYG{n}{np}\PYG{o}{.}\PYG{n}{zeros}\PYG{p}{(}\PYG{n}{n\PYGZus{}visible}\PYG{p}{)} |
| 24 | + \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{h\PYGZus{}bias} \PYG{o}{=} \PYG{n}{np}\PYG{o}{.}\PYG{n}{zeros}\PYG{p}{(}\PYG{n}{n\PYGZus{}hidden}\PYG{p}{)} |
| 25 | + |
| 26 | + \PYG{k}{def} \PYG{n+nf}{sigmoid}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{p}{,} \PYG{n}{x}\PYG{p}{):} |
| 27 | + \PYG{k}{return} \PYG{l+m+mi}{1} \PYG{o}{/} \PYG{p}{(}\PYG{l+m+mi}{1} \PYG{o}{+} \PYG{n}{np}\PYG{o}{.}\PYG{n}{exp}\PYG{p}{(}\PYG{o}{\PYGZhy{}}\PYG{n}{x}\PYG{p}{))} |
| 28 | + |
| 29 | + \PYG{k}{def} \PYG{n+nf}{sample}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{p}{,} \PYG{n}{probs}\PYG{p}{):} |
| 30 | + \PYG{k}{return} \PYG{p}{(}\PYG{n}{np}\PYG{o}{.}\PYG{n}{random}\PYG{o}{.}\PYG{n}{rand}\PYG{p}{(}\PYG{o}{*}\PYG{n}{probs}\PYG{o}{.}\PYG{n}{shape}\PYG{p}{)} \PYG{o}{\PYGZlt{}} \PYG{n}{probs}\PYG{p}{)}\PYG{o}{.}\PYG{n}{astype}\PYG{p}{(}\PYG{n}{np}\PYG{o}{.}\PYG{n}{float32}\PYG{p}{)} |
| 31 | + |
| 32 | + \PYG{k}{def} \PYG{n+nf}{train}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{p}{,} \PYG{n}{data}\PYG{p}{,} \PYG{n}{epochs}\PYG{o}{=}\PYG{l+m+mi}{10}\PYG{p}{,} \PYG{n}{batch\PYGZus{}size}\PYG{o}{=}\PYG{l+m+mi}{64}\PYG{p}{):} |
| 33 | + \PYG{n}{n\PYGZus{}samples} \PYG{o}{=} \PYG{n}{data}\PYG{o}{.}\PYG{n}{shape}\PYG{p}{[}\PYG{l+m+mi}{0}\PYG{p}{]} |
| 34 | + \PYG{c+c1}{\PYGZsh{} Convert the DataFrame to a NumPy array to avoid the KeyError.} |
| 35 | + \PYG{n}{data} \PYG{o}{=} \PYG{n}{data}\PYG{o}{.}\PYG{n}{to\PYGZus{}numpy}\PYG{p}{()} |
| 36 | + \PYG{k}{for} \PYG{n}{epoch} \PYG{o+ow}{in} \PYG{n+nb}{range}\PYG{p}{(}\PYG{n}{epochs}\PYG{p}{):} |
| 37 | + \PYG{n}{np}\PYG{o}{.}\PYG{n}{random}\PYG{o}{.}\PYG{n}{shuffle}\PYG{p}{(}\PYG{n}{data}\PYG{p}{)} |
| 38 | + \PYG{n}{epoch\PYGZus{}error} \PYG{o}{=} \PYG{l+m+mi}{0} |
| 39 | + |
| 40 | + \PYG{k}{for} \PYG{n}{i} \PYG{o+ow}{in} \PYG{n+nb}{range}\PYG{p}{(}\PYG{l+m+mi}{0}\PYG{p}{,} \PYG{n}{n\PYGZus{}samples}\PYG{p}{,} \PYG{n}{batch\PYGZus{}size}\PYG{p}{):} |
| 41 | + \PYG{n}{v0} \PYG{o}{=} \PYG{n}{data}\PYG{p}{[}\PYG{n}{i}\PYG{p}{:}\PYG{n}{i} \PYG{o}{+} \PYG{n}{batch\PYGZus{}size}\PYG{p}{]} |
| 42 | + \PYG{n}{h0\PYGZus{}prob} \PYG{o}{=} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{sigmoid}\PYG{p}{(}\PYG{n}{np}\PYG{o}{.}\PYG{n}{dot}\PYG{p}{(}\PYG{n}{v0}\PYG{p}{,} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{W}\PYG{p}{)} \PYG{o}{+} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{h\PYGZus{}bias}\PYG{p}{)} |
| 43 | + \PYG{n}{h0\PYGZus{}sample} \PYG{o}{=} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{sample}\PYG{p}{(}\PYG{n}{h0\PYGZus{}prob}\PYG{p}{)} |
| 44 | + |
| 45 | + \PYG{n}{v1\PYGZus{}prob} \PYG{o}{=} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{sigmoid}\PYG{p}{(}\PYG{n}{np}\PYG{o}{.}\PYG{n}{dot}\PYG{p}{(}\PYG{n}{h0\PYGZus{}sample}\PYG{p}{,} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{W}\PYG{o}{.}\PYG{n}{T}\PYG{p}{)} \PYG{o}{+} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{v\PYGZus{}bias}\PYG{p}{)} |
| 46 | + \PYG{n}{h1\PYGZus{}prob} \PYG{o}{=} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{sigmoid}\PYG{p}{(}\PYG{n}{np}\PYG{o}{.}\PYG{n}{dot}\PYG{p}{(}\PYG{n}{v1\PYGZus{}prob}\PYG{p}{,} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{W}\PYG{p}{)} \PYG{o}{+} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{h\PYGZus{}bias}\PYG{p}{)} |
| 47 | + |
| 48 | + \PYG{c+c1}{\PYGZsh{} Weight and bias updates} |
| 49 | + \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{W} \PYG{o}{+=} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{learning\PYGZus{}rate} \PYG{o}{*} \PYG{p}{(}\PYG{n}{np}\PYG{o}{.}\PYG{n}{dot}\PYG{p}{(}\PYG{n}{v0}\PYG{o}{.}\PYG{n}{T}\PYG{p}{,} \PYG{n}{h0\PYGZus{}prob}\PYG{p}{)} \PYG{o}{\PYGZhy{}} \PYG{n}{np}\PYG{o}{.}\PYG{n}{dot}\PYG{p}{(}\PYG{n}{v1\PYGZus{}prob}\PYG{o}{.}\PYG{n}{T}\PYG{p}{,} \PYG{n}{h1\PYGZus{}prob}\PYG{p}{))} \PYG{o}{/} \PYG{n}{batch\PYGZus{}size} |
| 50 | + \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{v\PYGZus{}bias} \PYG{o}{+=} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{learning\PYGZus{}rate} \PYG{o}{*} \PYG{n}{np}\PYG{o}{.}\PYG{n}{mean}\PYG{p}{(}\PYG{n}{v0} \PYG{o}{\PYGZhy{}} \PYG{n}{v1\PYGZus{}prob}\PYG{p}{,} \PYG{n}{axis}\PYG{o}{=}\PYG{l+m+mi}{0}\PYG{p}{)} |
| 51 | + \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{h\PYGZus{}bias} \PYG{o}{+=} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{learning\PYGZus{}rate} \PYG{o}{*} \PYG{n}{np}\PYG{o}{.}\PYG{n}{mean}\PYG{p}{(}\PYG{n}{h0\PYGZus{}prob} \PYG{o}{\PYGZhy{}} \PYG{n}{h1\PYGZus{}prob}\PYG{p}{,} \PYG{n}{axis}\PYG{o}{=}\PYG{l+m+mi}{0}\PYG{p}{)} |
| 52 | + |
| 53 | + \PYG{n}{epoch\PYGZus{}error} \PYG{o}{+=} \PYG{n}{np}\PYG{o}{.}\PYG{n}{mean}\PYG{p}{((}\PYG{n}{v0} \PYG{o}{\PYGZhy{}} \PYG{n}{v1\PYGZus{}prob}\PYG{p}{)} \PYG{o}{**} \PYG{l+m+mi}{2}\PYG{p}{)} |
| 54 | + |
| 55 | + \PYG{n+nb}{print}\PYG{p}{(}\PYG{l+s+sa}{f}\PYG{l+s+s2}{\PYGZdq{}Epoch }\PYG{l+s+si}{\PYGZob{}}\PYG{n}{epoch}\PYG{+w}{ }\PYG{o}{+}\PYG{+w}{ }\PYG{l+m+mi}{1}\PYG{l+s+si}{\PYGZcb{}}\PYG{l+s+s2}{: Reconstruction error = }\PYG{l+s+si}{\PYGZob{}}\PYG{n}{epoch\PYGZus{}error}\PYG{l+s+si}{:}\PYG{l+s+s2}{.4f}\PYG{l+s+si}{\PYGZcb{}}\PYG{l+s+s2}{\PYGZdq{}}\PYG{p}{)} |
| 56 | + |
| 57 | + \PYG{k}{def} \PYG{n+nf}{reconstruct}\PYG{p}{(}\PYG{n+nb+bp}{self}\PYG{p}{,} \PYG{n}{v}\PYG{p}{):} |
| 58 | + \PYG{n}{h} \PYG{o}{=} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{sigmoid}\PYG{p}{(}\PYG{n}{np}\PYG{o}{.}\PYG{n}{dot}\PYG{p}{(}\PYG{n}{v}\PYG{p}{,} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{W}\PYG{p}{)} \PYG{o}{+} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{h\PYGZus{}bias}\PYG{p}{)} |
| 59 | + \PYG{n}{v\PYGZus{}recon} \PYG{o}{=} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{sigmoid}\PYG{p}{(}\PYG{n}{np}\PYG{o}{.}\PYG{n}{dot}\PYG{p}{(}\PYG{n}{h}\PYG{p}{,} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{W}\PYG{o}{.}\PYG{n}{T}\PYG{p}{)} \PYG{o}{+} \PYG{n+nb+bp}{self}\PYG{o}{.}\PYG{n}{v\PYGZus{}bias}\PYG{p}{)} |
| 60 | + \PYG{k}{return} \PYG{n}{v\PYGZus{}recon} |
| 61 | + |
| 62 | +\PYG{c+c1}{\PYGZsh{} Load and split MNIST} |
| 63 | +\PYG{n}{X} \PYG{o}{=} \PYG{n}{load\PYGZus{}binarized\PYGZus{}mnist}\PYG{p}{()} |
| 64 | +\PYG{n}{X\PYGZus{}train}\PYG{p}{,} \PYG{n}{X\PYGZus{}test} \PYG{o}{=} \PYG{n}{train\PYGZus{}test\PYGZus{}split}\PYG{p}{(}\PYG{n}{X}\PYG{p}{,} \PYG{n}{test\PYGZus{}size}\PYG{o}{=}\PYG{l+m+mf}{0.1}\PYG{p}{,} \PYG{n}{random\PYGZus{}state}\PYG{o}{=}\PYG{l+m+mi}{42}\PYG{p}{)} |
| 65 | + |
| 66 | +\PYG{c+c1}{\PYGZsh{} Initialize and train RBM} |
| 67 | +\PYG{n}{rbm} \PYG{o}{=} \PYG{n}{RBM}\PYG{p}{(}\PYG{n}{n\PYGZus{}visible}\PYG{o}{=}\PYG{l+m+mi}{784}\PYG{p}{,} \PYG{n}{n\PYGZus{}hidden}\PYG{o}{=}\PYG{l+m+mi}{128}\PYG{p}{,} \PYG{n}{learning\PYGZus{}rate}\PYG{o}{=}\PYG{l+m+mf}{0.1}\PYG{p}{)} |
| 68 | +\PYG{n}{rbm}\PYG{o}{.}\PYG{n}{train}\PYG{p}{(}\PYG{n}{X\PYGZus{}train}\PYG{p}{,} \PYG{n}{epochs}\PYG{o}{=}\PYG{l+m+mi}{10}\PYG{p}{,} \PYG{n}{batch\PYGZus{}size}\PYG{o}{=}\PYG{l+m+mi}{64}\PYG{p}{)} |
| 69 | + |
| 70 | +\PYG{c+c1}{\PYGZsh{} Visualize reconstruction} |
| 71 | +\PYG{k}{def} \PYG{n+nf}{show\PYGZus{}reconstruction}\PYG{p}{(}\PYG{n}{original}\PYG{p}{,} \PYG{n}{reconstructed}\PYG{p}{):} |
| 72 | + \PYG{n}{fig}\PYG{p}{,} \PYG{n}{axes} \PYG{o}{=} \PYG{n}{plt}\PYG{o}{.}\PYG{n}{subplots}\PYG{p}{(}\PYG{l+m+mi}{1}\PYG{p}{,} \PYG{l+m+mi}{2}\PYG{p}{)} |
| 73 | + \PYG{n}{axes}\PYG{p}{[}\PYG{l+m+mi}{0}\PYG{p}{]}\PYG{o}{.}\PYG{n}{imshow}\PYG{p}{(}\PYG{n}{original}\PYG{o}{.}\PYG{n}{reshape}\PYG{p}{(}\PYG{l+m+mi}{28}\PYG{p}{,} \PYG{l+m+mi}{28}\PYG{p}{),} \PYG{n}{cmap}\PYG{o}{=}\PYG{l+s+s2}{\PYGZdq{}gray\PYGZdq{}}\PYG{p}{)} |
| 74 | + \PYG{n}{axes}\PYG{p}{[}\PYG{l+m+mi}{0}\PYG{p}{]}\PYG{o}{.}\PYG{n}{set\PYGZus{}title}\PYG{p}{(}\PYG{l+s+s2}{\PYGZdq{}Original\PYGZdq{}}\PYG{p}{)} |
| 75 | + \PYG{n}{axes}\PYG{p}{[}\PYG{l+m+mi}{1}\PYG{p}{]}\PYG{o}{.}\PYG{n}{imshow}\PYG{p}{(}\PYG{n}{reconstructed}\PYG{o}{.}\PYG{n}{reshape}\PYG{p}{(}\PYG{l+m+mi}{28}\PYG{p}{,} \PYG{l+m+mi}{28}\PYG{p}{),} \PYG{n}{cmap}\PYG{o}{=}\PYG{l+s+s2}{\PYGZdq{}gray\PYGZdq{}}\PYG{p}{)} |
| 76 | + \PYG{n}{axes}\PYG{p}{[}\PYG{l+m+mi}{1}\PYG{p}{]}\PYG{o}{.}\PYG{n}{set\PYGZus{}title}\PYG{p}{(}\PYG{l+s+s2}{\PYGZdq{}Reconstruction\PYGZdq{}}\PYG{p}{)} |
| 77 | + \PYG{n}{plt}\PYG{o}{.}\PYG{n}{show}\PYG{p}{()} |
| 78 | + |
| 79 | +\PYG{n}{sample} \PYG{o}{=} \PYG{n}{X\PYGZus{}test}\PYG{o}{.}\PYG{n}{iloc}\PYG{p}{[}\PYG{l+m+mi}{0}\PYG{p}{]}\PYG{o}{.}\PYG{n}{values} \PYG{c+c1}{\PYGZsh{} Access the first row and convert to NumPy array} |
| 80 | +\PYG{n}{reconstruction} \PYG{o}{=} \PYG{n}{rbm}\PYG{o}{.}\PYG{n}{reconstruct}\PYG{p}{(}\PYG{n}{sample}\PYG{p}{[}\PYG{n}{np}\PYG{o}{.}\PYG{n}{newaxis}\PYG{p}{,} \PYG{p}{:])} |
| 81 | +\PYG{n}{show\PYGZus{}reconstruction}\PYG{p}{(}\PYG{n}{sample}\PYG{p}{,} \PYG{n}{reconstruction}\PYG{p}{[}\PYG{l+m+mi}{0}\PYG{p}{])} |
| 82 | + |
| 83 | + |
| 84 | +\end{Verbatim} |
0 commit comments