Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 39fd769

Browse files
Ashish VaswaniRyan Sepassi
authored andcommitted
Add sampling with temperature and cifar10 8 by 8 dataset.
PiperOrigin-RevId: 172031867
1 parent ee922bd commit 39fd769

File tree

3 files changed

+30
-21
lines changed

3 files changed

+30
-21
lines changed

tensor2tensor/data_generators/image.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@
4242
import tensorflow as tf
4343

4444

45+
def resize_by_area(img, size):
46+
"""image resize function used by quite a few image problems."""
47+
return tf.to_int64(
48+
tf.image.resize_images(img, [size, size], tf.image.ResizeMethod.AREA))
49+
50+
4551
class ImageProblem(problem.Problem):
4652

4753
def example_reading_spec(self, label_key=None):
@@ -93,16 +99,12 @@ class ImageCeleba(ImageProblem):
9399

94100
def preprocess_example(self, example, unused_mode, unused_hparams):
95101

96-
def resize(img, size):
97-
return tf.to_int64(
98-
tf.image.resize_images(img, [size, size], tf.image.ResizeMethod.AREA))
99-
100102
inputs = example["inputs"]
101103
# Remove boundaries in CelebA images. Remove 40 pixels each side
102104
# vertically and 20 pixels each side horizontally.
103105
inputs = tf.image.crop_to_bounding_box(inputs, 40, 20, 218 - 80, 178 - 40)
104-
example["inputs"] = resize(inputs, 8)
105-
example["targets"] = resize(inputs, 32)
106+
example["inputs"] = resize_by_area(inputs, 8)
107+
example["targets"] = resize_by_area(inputs, 32)
106108
return example
107109

108110
def hparams(self, defaults, unused_model_hparams):
@@ -388,14 +390,10 @@ def dataset_filename(self):
388390

389391
def preprocess_example(self, example, unused_mode, unused_hparams):
390392

391-
def resize(img, size):
392-
return tf.to_int64(
393-
tf.image.resize_images(img, [size, size], tf.image.ResizeMethod.AREA))
394-
395393
inputs = example["inputs"]
396394
# For Img2Img resize input and output images as desired.
397-
example["inputs"] = resize(inputs, 8)
398-
example["targets"] = resize(inputs, 32)
395+
example["inputs"] = resize_by_area(inputs, 8)
396+
example["targets"] = resize_by_area(inputs, 32)
399397
return example
400398

401399
def hparams(self, defaults, unused_model_hparams):
@@ -654,6 +652,18 @@ def preprocess_example(self, example, mode, unused_hparams):
654652
return example
655653

656654

655+
@registry.register_problem
656+
class ImageCifar10Plain8(ImageCifar10):
657+
"""CIFAR-10 rescaled to 8x8 for output: Conditional image generation."""
658+
659+
def dataset_filename(self):
660+
return "image_cifar10_plain" # Reuse CIFAR-10 plain data.
661+
662+
def preprocess_example(self, example, mode, unused_hparams):
663+
example["inputs"] = resize_by_area(example["inputs"], 8)
664+
return example
665+
666+
657667
@registry.register_problem
658668
class Img2imgCifar10(ImageCifar10):
659669
"""CIFAR-10 rescaled to 8x8 for input and 32x32 for output."""
@@ -663,14 +673,10 @@ def dataset_filename(self):
663673

664674
def preprocess_example(self, example, unused_mode, unused_hparams):
665675

666-
def resize(img, size):
667-
return tf.to_int64(
668-
tf.image.resize_images(img, [size, size], tf.image.ResizeMethod.AREA))
669-
670676
inputs = example["inputs"]
671677
# For Img2Img resize input and output images as desired.
672-
example["inputs"] = resize(inputs, 8)
673-
example["targets"] = resize(inputs, 32)
678+
example["inputs"] = resize_by_area(inputs, 8)
679+
example["targets"] = resize_by_area(inputs, 32)
674680
return example
675681

676682
def hparams(self, defaults, unused_model_hparams):

tensor2tensor/layers/common_hparams.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def basic_params1():
6262
learning_rate_cosine_cycle_steps=250000,
6363
learning_rate=0.1,
6464
sampling_method="argmax", # "argmax" or "random"
65+
sampling_temp=1.0, # temperature for sampling
6566
problem_choice="adaptive", # "uniform", "adaptive", "distributed"
6667
# expand the logits a piece at a time - saves memory.
6768
factored_logits=int(False),

tensor2tensor/utils/t2t_model.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -427,15 +427,17 @@ def sample(self, features, last_position_only=False):
427427
else:
428428
assert self._hparams.sampling_method == "random"
429429

430-
def _multinomial_squeeze(logits):
431-
reshaped_logits = tf.reshape(logits, [-1, tf.shape(logits)[-1]])
430+
def _multinomial_squeeze(logits, temperature=1.0):
431+
reshaped_logits = (
432+
tf.reshape(logits, [-1, tf.shape(logits)[-1]])/temperature)
432433
choices = tf.multinomial(reshaped_logits, 1)
433434
choices = tf.reshape(choices,
434435
tf.shape(logits)[:logits.get_shape().ndims - 1])
435436
return choices
436437

437438
sharded_samples = self._data_parallelism(_multinomial_squeeze,
438-
sharded_logits)
439+
sharded_logits,
440+
self._hparams.sampling_temp)
439441
return tf.concat(sharded_samples, 0), sharded_logits, losses
440442

441443
def _shard_features(self, features): # pylint: disable=missing-docstring

0 commit comments

Comments
 (0)