4242import tensorflow as tf
4343
4444
45+ def resize_by_area (img , size ):
46+ """image resize function used by quite a few image problems."""
47+ return tf .to_int64 (
48+ tf .image .resize_images (img , [size , size ], tf .image .ResizeMethod .AREA ))
49+
50+
4551class ImageProblem (problem .Problem ):
4652
4753 def example_reading_spec (self , label_key = None ):
@@ -93,16 +99,12 @@ class ImageCeleba(ImageProblem):
9399
94100 def preprocess_example (self , example , unused_mode , unused_hparams ):
95101
96- def resize (img , size ):
97- return tf .to_int64 (
98- tf .image .resize_images (img , [size , size ], tf .image .ResizeMethod .AREA ))
99-
100102 inputs = example ["inputs" ]
101103 # Remove boundaries in CelebA images. Remove 40 pixels each side
102104 # vertically and 20 pixels each side horizontally.
103105 inputs = tf .image .crop_to_bounding_box (inputs , 40 , 20 , 218 - 80 , 178 - 40 )
104- example ["inputs" ] = resize (inputs , 8 )
105- example ["targets" ] = resize (inputs , 32 )
106+ example ["inputs" ] = resize_by_area (inputs , 8 )
107+ example ["targets" ] = resize_by_area (inputs , 32 )
106108 return example
107109
108110 def hparams (self , defaults , unused_model_hparams ):
@@ -388,14 +390,10 @@ def dataset_filename(self):
388390
389391 def preprocess_example (self , example , unused_mode , unused_hparams ):
390392
391- def resize (img , size ):
392- return tf .to_int64 (
393- tf .image .resize_images (img , [size , size ], tf .image .ResizeMethod .AREA ))
394-
395393 inputs = example ["inputs" ]
396394 # For Img2Img resize input and output images as desired.
397- example ["inputs" ] = resize (inputs , 8 )
398- example ["targets" ] = resize (inputs , 32 )
395+ example ["inputs" ] = resize_by_area (inputs , 8 )
396+ example ["targets" ] = resize_by_area (inputs , 32 )
399397 return example
400398
401399 def hparams (self , defaults , unused_model_hparams ):
@@ -654,6 +652,18 @@ def preprocess_example(self, example, mode, unused_hparams):
654652 return example
655653
656654
655+ @registry .register_problem
656+ class ImageCifar10Plain8 (ImageCifar10 ):
657+ """CIFAR-10 rescaled to 8x8 for output: Conditional image generation."""
658+
659+ def dataset_filename (self ):
660+ return "image_cifar10_plain" # Reuse CIFAR-10 plain data.
661+
662+ def preprocess_example (self , example , mode , unused_hparams ):
663+ example ["inputs" ] = resize_by_area (example ["inputs" ], 8 )
664+ return example
665+
666+
657667@registry .register_problem
658668class Img2imgCifar10 (ImageCifar10 ):
659669 """CIFAR-10 rescaled to 8x8 for input and 32x32 for output."""
@@ -663,14 +673,10 @@ def dataset_filename(self):
663673
664674 def preprocess_example (self , example , unused_mode , unused_hparams ):
665675
666- def resize (img , size ):
667- return tf .to_int64 (
668- tf .image .resize_images (img , [size , size ], tf .image .ResizeMethod .AREA ))
669-
670676 inputs = example ["inputs" ]
671677 # For Img2Img resize input and output images as desired.
672- example ["inputs" ] = resize (inputs , 8 )
673- example ["targets" ] = resize (inputs , 32 )
678+ example ["inputs" ] = resize_by_area (inputs , 8 )
679+ example ["targets" ] = resize_by_area (inputs , 32 )
674680 return example
675681
676682 def hparams (self , defaults , unused_model_hparams ):
0 commit comments