@@ -229,6 +229,76 @@ def loss(self, top_out, targets, weights_fn=common_layers.weights_all):
229229 top_out , targets , weights_fn = weights_fn )
230230
231231
@registry.register_image_modality("image_identity_compress")
class ImageIdentityCompressModality(modality.Modality):
  """Image modality for generation that compresses per-pixel values.

  The bottom transforms fold the last two input axes together and run a
  width-3, stride-3 convolution over them; `top` expands the body output
  back to one set of logits per pixel channel.
  """

  @property
  def top_dimensionality(self):
    """Size of the last axis produced by `top` (fixed at 256)."""
    return 256

  def bottom_compress(self, inputs, name="bottom"):
    """Map inputs from data space into model space.

    RGB values are converted to reals, the last two axes are merged into a
    single axis, and a strided convolution then combines every run of three
    consecutive values along that merged axis.

    Args:
      inputs: A Tensor with shape [batch, ...]; presumably
        [batch, height, width, channels] -- confirm against callers.
      name: string, variable scope under which the layers are created.

    Returns:
      The output of the compression conv block; per the original contract a
      Tensor with shape [batch, ?, ?, body_input_depth].
    """
    with tf.variable_scope(name):
      real_valued = common_layers.convert_rgb_to_real(inputs)
      dynamic_shape = tf.shape(real_valued)
      rows = dynamic_shape[1]
      merged_cols = dynamic_shape[2] * dynamic_shape[3]
      flattened = tf.reshape(real_valued, [-1, rows, merged_cols, 1])
      # The reshape uses dynamic sizes, so restore the statically known
      # trailing dimension for downstream layers.
      flattened.set_shape([None, None, None, 1])
      # Kernel width 3 with stride 3 collapses each group of three adjacent
      # values (the RGB intensities of one pixel, assuming 3 channels).
      return common_layers.conv_block(
          flattened,
          self._body_input_depth,
          [((1, 1), (1, 3))],
          first_relu=False,
          padding="VALID",
          strides=(1, 3),
          force2d=True,
          name="conv_input")

  def bottom(self, inputs):
    """Apply `bottom_compress` to the inputs under scope "input_bottom"."""
    return self.bottom_compress(inputs, "input_bottom")

  def targets_bottom(self, inputs):
    """Apply `bottom_compress` to the targets under scope "output_bottom"."""
    return self.bottom_compress(inputs, "output_bottom")

  def top(self, body_output, _):
    """Expand the body output into per-channel output values.

    Args:
      body_output: A Tensor produced by the model body; its leading axis is
        used as the batch size.
      _: unused.

    Returns:
      A Tensor reshaped to
      [-1, img_len, img_len, num_channels, top_dimensionality].
    """
    with tf.variable_scope(self.name):
      hparams = self._model_hparams
      hidden_size = hparams.hidden_size
      side = hparams.img_len
      num_channels = hparams.num_channels
      out_depth = self.top_dimensionality
      batch_size = tf.shape(body_output)[0]

      # A 1x1 conv widens the depth so each position carries one hidden
      # vector per channel.
      expanded = common_layers.conv(
          body_output,
          hidden_size * num_channels,
          (1, 1),
          padding="VALID",
          activation=tf.nn.relu,
          name="decompress_conv")

      # Unfold the widened depth along the column axis.
      unfolded = tf.reshape(
          expanded, [batch_size, side, side * num_channels, hidden_size])
      unfolded.set_shape([None, None, None, hidden_size])

      projected = common_layers.conv(
          unfolded, out_depth, (1, 1), name="output_conv")
      return tf.reshape(
          projected, [-1, side, side, num_channels, out_depth])

  def loss(self, top_out, targets, weights_fn=common_layers.weights_all):
    """Compute the loss via the parent implementation.

    The only difference from the parent is the default `weights_fn`:
    `common_layers.weights_all` is used so that zero-valued targets are not
    treated as padding (images have no padding but do contain 0 pixels).
    """
    parent = super(ImageIdentityCompressModality, self)
    return parent.loss(top_out, targets, weights_fn=weights_fn)
300+
301+
232302@registry .register_audio_modality ("default" )
233303class AudioModality (modality .Modality ):
234304 """Performs strided conv compressions for audio data."""
0 commit comments