import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# Load and preprocess MNIST
def load_binarized_mnist():
    print("Downloading MNIST...")
    # fetch_openml returns a pandas DataFrame by default and caches the
    # download locally, so repeated runs are fast.
    mnist = fetch_openml('mnist_784', version=1)
    X = mnist.data.astype(np.float32) / 255.0  # scale pixel values to [0, 1]
    X = (X > 0.5).astype(np.float32)  # binarize at a 0.5 threshold
    return X

class RBM:
    def __init__(self, n_visible, n_hidden, learning_rate=0.1):
        self.n_visible = n_visible
        self.n_hidden = n_hidden
        self.learning_rate = learning_rate

        # Initialize weights with small Gaussian noise, biases at zero
        self.W = np.random.normal(0, 0.01, size=(n_visible, n_hidden))
        self.v_bias = np.zeros(n_visible)
        self.h_bias = np.zeros(n_hidden)

    def sigmoid(self, x):
        # Clip the pre-activation to avoid overflow warnings in np.exp
        return 1 / (1 + np.exp(-np.clip(x, -30, 30)))

    def sample(self, probs):
        # Bernoulli sampling: each unit turns on with its given probability
        return (np.random.rand(*probs.shape) < probs).astype(np.float32)

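    # Training uses one-step contrastive divergence (CD-1): infer hidden
    # activations from the data (positive phase), reconstruct the visibles
    # and re-infer the hiddens (negative phase), then move the weights
    # toward the data statistics and away from the reconstruction's.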
    def train(self, data, epochs=10, batch_size=64):
        n_samples = data.shape[0]
        # Convert the DataFrame to a NumPy array; np.random.shuffle below
        # raises a KeyError when handed a DataFrame.
        data = data.to_numpy()
        for epoch in range(epochs):
            np.random.shuffle(data)
            epoch_error = 0
            n_batches = 0

            for i in range(0, n_samples, batch_size):
                v0 = data[i:i + batch_size]
                # Positive phase: hidden activations driven by the data
                h0_prob = self.sigmoid(np.dot(v0, self.W) + self.h_bias)
                h0_sample = self.sample(h0_prob)

                # Negative phase: reconstruct visibles, re-infer hiddens
                v1_prob = self.sigmoid(np.dot(h0_sample, self.W.T) + self.v_bias)
                h1_prob = self.sigmoid(np.dot(v1_prob, self.W) + self.h_bias)

                # Weight and bias updates
                self.W += self.learning_rate * (np.dot(v0.T, h0_prob) - np.dot(v1_prob.T, h1_prob)) / batch_size
                self.v_bias += self.learning_rate * np.mean(v0 - v1_prob, axis=0)
                self.h_bias += self.learning_rate * np.mean(h0_prob - h1_prob, axis=0)

                epoch_error += np.mean((v0 - v1_prob) ** 2)
                n_batches += 1

            # Report the mean per-batch error so the number is comparable
            # across different batch sizes.
            print(f"Epoch {epoch + 1}: Reconstruction error = {epoch_error / n_batches:.4f}")

    def reconstruct(self, v):
        # Mean-field pass: propagate probabilities rather than samples
        h = self.sigmoid(np.dot(v, self.W) + self.h_bias)
        v_recon = self.sigmoid(np.dot(h, self.W.T) + self.v_bias)
        return v_recon

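# Illustrative extra, a minimal sketch: draw new digits from the trained
# model by running a short block Gibbs chain from random noise. The helper
# name and the chain length are arbitrary choices, not part of the class.
def gibbs_sample(rbm, n_steps=1000):
    v = (np.random.rand(1, rbm.n_visible) > 0.5).astype(np.float32)
    for _ in range(n_steps):
        h = rbm.sample(rbm.sigmoid(np.dot(v, rbm.W) + rbm.h_bias))
        v = rbm.sample(rbm.sigmoid(np.dot(h, rbm.W.T) + rbm.v_bias))
    return v
# e.g. plt.imshow(gibbs_sample(rbm)[0].reshape(28, 28), cmap="gray")
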
# Load and split MNIST
X = load_binarized_mnist()
X_train, X_test = train_test_split(X, test_size=0.1, random_state=42)

# Initialize and train RBM
rbm = RBM(n_visible=784, n_hidden=128, learning_rate=0.1)
rbm.train(X_train, epochs=10, batch_size=64)

# Visualize reconstruction
def show_reconstruction(original, reconstructed):
    fig, axes = plt.subplots(1, 2)
    axes[0].imshow(original.reshape(28, 28), cmap="gray")
    axes[0].set_title("Original")
    axes[1].imshow(reconstructed.reshape(28, 28), cmap="gray")
    axes[1].set_title("Reconstruction")
    plt.show()

sample = X_test.iloc[0].values  # first test row as a NumPy array
reconstruction = rbm.reconstruct(sample[np.newaxis, :])  # add a batch dimension
show_reconstruction(sample, reconstruction[0])
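
# Optional sanity check, a sketch: mean reconstruction MSE over the whole
# held-out split. The absolute value depends on the 0.5 binarization
# threshold chosen above, so treat it as a relative measure only.
test_data = X_test.to_numpy()
test_recon = rbm.reconstruct(test_data)
print(f"Held-out reconstruction MSE: {np.mean((test_data - test_recon) ** 2):.4f}")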