|
| 1 | +''' |
| 2 | + Import the KMNIST dataset |
| 3 | + Source: https://arxiv.org/abs/1812.01718 |
| 4 | + Description: Kuzushiji-MNIST, an MNIST-style dataset of handwritten classical Japanese characters. |
| 5 | +
|
| 6 | + ~~~ Important note ~~~ |
| 7 | + Please cite the following paper when using or referencing the dataset: |
| 8 | + Clanuwat, T., Bober-Irizar, M., Kitamoto, A., Lamb, A., Yamamoto, K., & Ha, D. (2018). Deep learning for classical Japanese literature. arXiv preprint arXiv:1812.01718. Retrieved from https://arxiv.org/abs/1812.01718 |
| 9 | +
|
| 10 | +''' |
| 11 | + |
| 12 | +from keras.utils.data_utils import get_file |
| 13 | +import numpy as np |
| 14 | + |
def load_data(path='kmnist.npz', type='kmnist'):
    """Loads the KMNIST dataset.

    # Arguments
        path: path where to cache the dataset locally
            (relative to ~/.keras/datasets). It is used as a prefix:
            each of the four downloaded archives is cached under
            `{path}_{type}_{split}_{kind}`.
        type: any of kmnist, k49

    # Returns
        Tuple of Numpy arrays: `(input_train, target_train), (input_test, target_test)`.

    # Raises
        ValueError: if `type` is not one of the supported dataset variants.
    """
    # `type` shadows the builtin, but it is part of the public signature and
    # is kept for backward compatibility with keyword callers.
    supported_types = ('kmnist', 'k49')
    if type not in supported_types:
        # Fail early with a clear message instead of letting an invalid
        # variant turn into an opaque download error.
        raise ValueError(
            f'Unknown KMNIST dataset type {type!r}; '
            f'expected one of {supported_types}.'
        )

    def _fetch_array(split, kind):
        # Download (or reuse the cached copy of) one archive and return the
        # single array stored in it.
        archive_path = get_file(
            f'{path}_{type}_{split}_{kind}',
            origin=f'http://codh.rois.ac.jp/kmnist/dataset/{type}/{type}-{split}-{kind}.npz'
        )
        # np.load returns an NpzFile holding an open file handle; the context
        # manager closes it once the array has been read into memory.
        with np.load(archive_path) as archive:
            return archive['arr_0']

    # Same fetch order as before: train images, train labels, test images,
    # test labels.
    input_train = _fetch_array('train', 'imgs')
    target_train = _fetch_array('train', 'labels')
    input_test = _fetch_array('test', 'imgs')
    target_test = _fetch_array('test', 'labels')

    return (input_train, target_train), (input_test, target_test)
0 commit comments