Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 2239ad0

Browse files
T2T Team authored and Ryan Sepassi committed
Add CIFAR-100 dataset and generator
PiperOrigin-RevId: 186042711
1 parent c471341 commit 2239ad0

File tree

1 file changed

+241
-16
lines changed

1 file changed

+241
-16
lines changed

tensor2tensor/data_generators/cifar.py

Lines changed: 241 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -43,20 +43,26 @@
4343
"data_batch_5"
4444
]
4545
_CIFAR10_TEST_FILES = ["test_batch"]
46-
_CIFAR10_IMAGE_SIZE = 32
46+
_CIFAR10_IMAGE_SIZE = _CIFAR100_IMAGE_SIZE = 32
4747

48+
_CIFAR100_URL = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
49+
_CIFAR100_PREFIX = "cifar-100-python/"
50+
_CIFAR100_TRAIN_FILES = ["train"]
51+
_CIFAR100_TEST_FILES = ["test"]
4852

49-
def _get_cifar10(directory):
53+
54+
def _get_cifar(directory, url):
5055
"""Download and extract CIFAR to directory unless it is there."""
51-
filename = os.path.basename(_CIFAR10_URL)
52-
path = generator_utils.maybe_download(directory, filename, _CIFAR10_URL)
56+
filename = os.path.basename(url)
57+
path = generator_utils.maybe_download(directory, filename, url)
5358
tarfile.open(path, "r:gz").extractall(directory)
5459

5560

56-
def cifar10_generator(tmp_dir, training, how_many, start_from=0):
57-
"""Image generator for CIFAR-10.
61+
def cifar_generator(cifar_version, tmp_dir, training, how_many, start_from=0):
62+
"""Image generator for CIFAR-10 and 100.
5863
5964
Args:
65+
cifar_version: string; one of "cifar10" or "cifar100"
6066
tmp_dir: path to temporary storage directory.
6167
training: a Boolean; if true, we use the train set, otherwise the test set.
6268
how_many: how many images and labels to generate.
@@ -65,21 +71,33 @@ def cifar10_generator(tmp_dir, training, how_many, start_from=0):
6571
Returns:
6672
An instance of image_generator that produces CIFAR-10 images and labels.
6773
"""
68-
_get_cifar10(tmp_dir)
69-
data_files = _CIFAR10_TRAIN_FILES if training else _CIFAR10_TEST_FILES
74+
if cifar_version == "cifar10":
75+
url = _CIFAR10_URL
76+
train_files = _CIFAR10_TRAIN_FILES
77+
test_files = _CIFAR10_TEST_FILES
78+
prefix = _CIFAR10_PREFIX
79+
image_size = _CIFAR10_IMAGE_SIZE
80+
elif cifar_version == "cifar100":
81+
url = _CIFAR100_URL
82+
train_files = _CIFAR100_TRAIN_FILES
83+
test_files = _CIFAR100_TEST_FILES
84+
prefix = _CIFAR100_PREFIX
85+
image_size = _CIFAR100_IMAGE_SIZE
86+
87+
_get_cifar(tmp_dir, url)
88+
data_files = train_files if training else test_files
7089
all_images, all_labels = [], []
7190
for filename in data_files:
72-
path = os.path.join(tmp_dir, _CIFAR10_PREFIX, filename)
91+
path = os.path.join(tmp_dir, prefix, filename)
7392
with tf.gfile.Open(path, "r") as f:
7493
data = cPickle.load(f)
7594
images = data["data"]
7695
num_images = images.shape[0]
77-
images = images.reshape((num_images, 3, _CIFAR10_IMAGE_SIZE,
78-
_CIFAR10_IMAGE_SIZE))
96+
images = images.reshape((num_images, 3, image_size, image_size))
7997
all_images.extend([
8098
np.squeeze(images[j]).transpose((1, 2, 0)) for j in xrange(num_images)
8199
])
82-
labels = data["labels"]
100+
labels = data["labels" if cifar_version == "cifar10" else "fine_labels"]
83101
all_labels.extend([labels[j] for j in xrange(num_images)])
84102
return image_utils.image_generator(
85103
all_images[start_from:start_from + how_many],
@@ -112,19 +130,19 @@ def preprocess_example(self, example, mode, unused_hparams):
112130

113131
def generator(self, data_dir, tmp_dir, is_training):
114132
if is_training:
115-
return cifar10_generator(tmp_dir, True, 48000)
133+
return cifar_generator("cifar10", tmp_dir, True, 48000)
116134
else:
117-
return cifar10_generator(tmp_dir, True, 2000, 48000)
135+
return cifar_generator("cifar10", tmp_dir, True, 2000, 48000)
118136

119137

120138
@registry.register_problem
121139
class ImageCifar10(ImageCifar10Tune):
122140

123141
def generator(self, data_dir, tmp_dir, is_training):
124142
if is_training:
125-
return cifar10_generator(tmp_dir, True, 50000)
143+
return cifar_generator("cifar10", tmp_dir, True, 50000)
126144
else:
127-
return cifar10_generator(tmp_dir, False, 10000)
145+
return cifar_generator("cifar10", tmp_dir, False, 10000)
128146

129147

130148
@registry.register_problem
@@ -188,3 +206,210 @@ def hparams(self, defaults, unused_model_hparams):
188206
p.batch_size_multiplier = 256
189207
p.input_space_id = 1
190208
p.target_space_id = 1
209+
210+
211+
@registry.register_problem
212+
class ImageCifar100Tune(mnist.ImageMnistTune):
213+
"""Cifar-100 Tune."""
214+
215+
@property
216+
def num_classes(self):
217+
return 100
218+
219+
@property
220+
def num_channels(self):
221+
return 3
222+
223+
@property
224+
def class_labels(self):
225+
return [
226+
"beaver",
227+
"dolphin",
228+
"otter",
229+
"seal",
230+
"whale",
231+
"aquarium fish",
232+
"flatfish",
233+
"ray",
234+
"shark",
235+
"trout",
236+
"orchids",
237+
"poppies",
238+
"roses",
239+
"sunflowers",
240+
"tulips",
241+
"bottles",
242+
"bowls",
243+
"cans",
244+
"cups",
245+
"plates",
246+
"apples",
247+
"mushrooms",
248+
"oranges",
249+
"pears",
250+
"sweet peppers",
251+
"clock",
252+
"computer keyboard",
253+
"lamp",
254+
"telephone",
255+
"television",
256+
"bed",
257+
"chair",
258+
"couch",
259+
"table",
260+
"wardrobe",
261+
"bee",
262+
"beetle",
263+
"butterfly",
264+
"caterpillar",
265+
"cockroach",
266+
"bear",
267+
"leopard",
268+
"lion",
269+
"tiger",
270+
"wolf",
271+
"bridge",
272+
"castle",
273+
"house",
274+
"road",
275+
"skyscraper",
276+
"cloud",
277+
"forest",
278+
"mountain",
279+
"plain",
280+
"sea",
281+
"camel",
282+
"cattle",
283+
"chimpanzee",
284+
"elephant",
285+
"kangaroo",
286+
"fox",
287+
"porcupine",
288+
"possum",
289+
"raccoon",
290+
"skunk",
291+
"crab",
292+
"lobster",
293+
"snail",
294+
"spider",
295+
"worm",
296+
"baby",
297+
"boy",
298+
"girl",
299+
"man",
300+
"woman",
301+
"crocodile",
302+
"dinosaur",
303+
"lizard",
304+
"snake",
305+
"turtle",
306+
"hamster",
307+
"mouse",
308+
"rabbit",
309+
"shrew",
310+
"squirrel",
311+
"maple",
312+
"oak",
313+
"palm",
314+
"pine",
315+
"willow",
316+
"bicycle",
317+
"bus",
318+
"motorcycle",
319+
"pickup truck",
320+
"train",
321+
"lawn-mower",
322+
"rocket",
323+
"streetcar",
324+
"tank",
325+
"tractor",
326+
]
327+
328+
def preprocess_example(self, example, mode, unused_hparams):
329+
image = example["inputs"]
330+
image.set_shape([_CIFAR100_IMAGE_SIZE, _CIFAR100_IMAGE_SIZE, 3])
331+
if mode == tf.estimator.ModeKeys.TRAIN:
332+
image = image_utils.cifar_image_augmentation(image)
333+
image = tf.image.per_image_standardization(image)
334+
example["inputs"] = image
335+
return example
336+
337+
def generator(self, data_dir, tmp_dir, is_training):
338+
if is_training:
339+
return cifar_generator("cifar100", tmp_dir, True, 48000)
340+
else:
341+
return cifar_generator("cifar100", tmp_dir, True, 2000, 48000)
342+
343+
344+
@registry.register_problem
345+
class ImageCifar100(ImageCifar100Tune):
346+
347+
def generator(self, data_dir, tmp_dir, is_training):
348+
if is_training:
349+
return cifar_generator("cifar100", tmp_dir, True, 50000)
350+
else:
351+
return cifar_generator("cifar100", tmp_dir, False, 10000)
352+
353+
354+
@registry.register_problem
355+
class ImageCifar100Plain(ImageCifar100):
356+
357+
def preprocess_example(self, example, mode, unused_hparams):
358+
image = example["inputs"]
359+
image.set_shape([_CIFAR100_IMAGE_SIZE, _CIFAR100_IMAGE_SIZE, 3])
360+
image = tf.image.per_image_standardization(image)
361+
example["inputs"] = image
362+
return example
363+
364+
365+
@registry.register_problem
366+
class ImageCifar100PlainGen(ImageCifar100Plain):
367+
"""CIFAR-100 32x32 for image generation without standardization preprep."""
368+
369+
def dataset_filename(self):
370+
return "image_cifar100_plain" # Reuse CIFAR-100 plain data.
371+
372+
def preprocess_example(self, example, mode, unused_hparams):
373+
example["inputs"].set_shape([_CIFAR100_IMAGE_SIZE, _CIFAR100_IMAGE_SIZE, 3])
374+
example["inputs"] = tf.to_int64(example["inputs"])
375+
return example
376+
377+
378+
@registry.register_problem
379+
class ImageCifar100Plain8(ImageCifar100):
380+
"""CIFAR-100 rescaled to 8x8 for output: Conditional image generation."""
381+
382+
def dataset_filename(self):
383+
return "image_cifar100_plain" # Reuse CIFAR-100 plain data.
384+
385+
def preprocess_example(self, example, mode, unused_hparams):
386+
image = example["inputs"]
387+
image = image_utils.resize_by_area(image, 8)
388+
image = tf.image.per_image_standardization(image)
389+
example["inputs"] = image
390+
return example
391+
392+
393+
@registry.register_problem
394+
class Img2imgCifar100(ImageCifar100):
395+
"""CIFAR-100 rescaled to 8x8 for input and 32x32 for output."""
396+
397+
def dataset_filename(self):
398+
return "image_cifar100_plain" # Reuse CIFAR-100 plain data.
399+
400+
def preprocess_example(self, example, unused_mode, unused_hparams):
401+
402+
inputs = example["inputs"]
403+
# For Img2Img resize input and output images as desired.
404+
example["inputs"] = image_utils.resize_by_area(inputs, 8)
405+
example["targets"] = image_utils.resize_by_area(inputs, 32)
406+
return example
407+
408+
def hparams(self, defaults, unused_model_hparams):
409+
p = defaults
410+
p.input_modality = {"inputs": ("image:identity", 256)}
411+
p.target_modality = ("image:identity", 256)
412+
p.batch_size_multiplier = 256
413+
p.max_expected_batch_size_per_shard = 4
414+
p.input_space_id = 1
415+
p.target_space_id = 1

0 commit comments

Comments
 (0)