@@ -383,13 +383,6 @@ def dataset(self,
     # Construct the Problem's hparams so that items within it are accessible
     _ = self.get_hparams(hparams)
 
-    data_fields, data_items_to_decoders = self.example_reading_spec()
-    if data_items_to_decoders is None:
-      data_items_to_decoders = {
-          field: tf.contrib.slim.tfexample_decoder.Tensor(field)
-          for field in data_fields
-      }
-
     is_training = mode == tf.estimator.ModeKeys.TRAIN
     data_filepattern = self.filepattern(data_dir, dataset_split, shard=shard)
     tf.logging.info("Reading data files from %s", data_filepattern)
@@ -406,22 +399,13 @@ def dataset(self,
     else:
       dataset = tf.data.TFRecordDataset(data_files)
 
-    def decode_record(record):
-      """Serialized Example to dict of <feature name, Tensor>."""
-      decoder = tf.contrib.slim.tfexample_decoder.TFExampleDecoder(
-          data_fields, data_items_to_decoders)
-
-      decode_items = list(data_items_to_decoders)
-      decoded = decoder.decode(record, items=decode_items)
-      return dict(zip(decode_items, decoded))
-
     def _preprocess(example):
       example = self.preprocess_example(example, mode, hparams)
       self.maybe_reverse_features(example)
       self.maybe_copy_features(example)
       return example
 
-    dataset = dataset.map(decode_record, num_parallel_calls=num_threads)
+    dataset = dataset.map(self.decode_example, num_parallel_calls=num_threads)
 
     if preprocess:
       dataset = dataset.map(_preprocess, num_parallel_calls=num_threads)
@@ -430,6 +414,22 @@ def _preprocess(example):
 
     return dataset
 
+  def decode_example(self, serialized_example):
+    """Return a dict of Tensors from a serialized tensorflow.Example."""
+    data_fields, data_items_to_decoders = self.example_reading_spec()
+    if data_items_to_decoders is None:
+      data_items_to_decoders = {
+          field: tf.contrib.slim.tfexample_decoder.Tensor(field)
+          for field in data_fields
+      }
+
+    decoder = tf.contrib.slim.tfexample_decoder.TFExampleDecoder(
+        data_fields, data_items_to_decoders)
+
+    decode_items = list(data_items_to_decoders)
+    decoded = decoder.decode(serialized_example, items=decode_items)
+    return dict(zip(decode_items, decoded))
+
   @property
   def has_inputs(self):
     return "inputs" in self.get_feature_encoders()
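With the decoder spec hoisted out of `dataset`, `decode_example` picks up whatever `example_reading_spec` declares, so a subclass only has to describe its fields. A minimal sketch of that contract, using a hypothetical `MyTextProblem` that is not part of this commit:

```python
import tensorflow as tf
from tensor2tensor.data_generators import problem


class MyTextProblem(problem.Problem):  # hypothetical subclass for illustration
  def example_reading_spec(self):
    data_fields = {
        "inputs": tf.VarLenFeature(tf.int64),
        "targets": tf.VarLenFeature(tf.int64),
    }
    # Returning None for the decoders lets decode_example build the default
    # tfexample_decoder.Tensor decoder for each declared field.
    return data_fields, None


# decode_example can then be mapped directly over raw TFRecords:
# dataset = tf.data.TFRecordDataset(files).map(my_problem.decode_example)
```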
@@ -496,7 +496,8 @@ def input_fn(self, mode, hparams, params=None, config=None,
       mode: tf.estimator.ModeKeys
       hparams: HParams, model hparams
       params: dict, may include "batch_size"
-      config: RunConfig; if passed, should include t2t_device_info dict
+      config: RunConfig; should have the data_parallelism attribute if not
+        using TPU
       dataset_kwargs: dict, if passed, will pass as kwargs to self.dataset
         method when called
@@ -521,29 +522,8 @@ def gpu_valid_size(example):
           hparams.max_length if drop_long_sequences else 10**9)
 
     def define_shapes(example):
-      """Set the right shapes for the features."""
-      inputs = example["inputs"]
-      targets = example["targets"]
-
-      # Ensure inputs and targets are proper rank.
-      while len(inputs.get_shape()) < 4:
-        inputs = tf.expand_dims(inputs, axis=-1)
-      while len(targets.get_shape()) < 4:
-        targets = tf.expand_dims(targets, axis=-1)
-
-      example["inputs"] = inputs
-      example["targets"] = targets
-
-      if config.use_tpu:
-        # Ensure batch size is set on all features
-        for _, t in six.iteritems(example):
-          shape = t.get_shape().as_list()
-          shape[0] = params["batch_size"]
-          t.set_shape(t.get_shape().merge_with(shape))
-          # Assert shapes are fully known
-          t.get_shape().assert_is_fully_defined()
-
-      return example
+      return _standardize_shapes(
+          example, batch_size=(config.use_tpu and params["batch_size"]))
 
     # Read and preprocess
     data_dir = hparams.data_dir
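One subtlety in the new one-liner: `batch_size=(config.use_tpu and params["batch_size"])` relies on Python's short-circuit `and`, which evaluates to `False` off TPU and to the actual batch size on TPU; `_standardize_shapes` only pins the batch dimension when it receives a truthy value. A quick illustration, not from the commit:

```python
params = {"batch_size": 128}

use_tpu = False
print(use_tpu and params["batch_size"])  # False -> batch dim left dynamic

use_tpu = True
print(use_tpu and params["batch_size"])  # 128 -> batch dim pinned to 128
```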
@@ -569,7 +549,7 @@ def define_shapes(example):
         dataset = dataset.apply(
             tf.contrib.data.batch_and_drop_remainder(tpu_batch_size))
       else:
-        num_shards = config.t2t_device_info["num_shards"]
+        num_shards = config.data_parallelism.n
         dataset = dataset.batch(hparams.batch_size * num_shards)
     else:
       # Variable length features
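For reference, `data_parallelism` is not a standard `RunConfig` field: tensor2tensor's trainer utilities attach a Parallelism-style object to the run config, and `.n` is its device count. A rough stand-in sketch of the attribute this code now expects; the class and names here are illustrative, not the actual t2t implementation:

```python
class FakeParallelism(object):
  """Illustrative stand-in exposing the one attribute input_fn reads."""

  def __init__(self, devices):
    self.devices = devices

  @property
  def n(self):
    return len(self.devices)  # number of data-parallel shards


# config.data_parallelism = FakeParallelism(["gpu:0", "gpu:1"])
# config.data_parallelism.n == 2, replacing t2t_device_info["num_shards"]
```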
@@ -586,7 +566,7 @@ def define_shapes(example):
         dataset = dataset.filter(gpu_valid_size)
       batching_scheme = data_reader.hparams_to_batching_scheme(
           hparams,
-          shard_multiplier=config.t2t_device_info["num_shards"],
+          shard_multiplier=config.data_parallelism.n,
           length_multiplier=self.get_hparams().batch_size_multiplier)
       if hparams.use_fixed_batch_size:
         batching_scheme["batch_sizes"] = [hparams.batch_size]
@@ -601,7 +581,7 @@ def define_shapes(example):
     dataset = dataset.prefetch(1)
     features = dataset.make_one_shot_iterator().get_next()
     if not config.use_tpu:
-      _summarize_features(features, config.t2t_device_info["num_shards"])
+      _summarize_features(features, config.data_parallelism.n)
 
     if mode == tf.estimator.ModeKeys.PREDICT:
       features["infer_targets"] = features["targets"]
@@ -614,6 +594,25 @@ def define_shapes(example):
 
     return features, features["targets"]
 
+  def serving_input_fn(self, hparams):
+    """Input fn for serving export, starting from serialized example."""
+    mode = tf.estimator.ModeKeys.PREDICT
+    serialized_example = tf.placeholder(
+        dtype=tf.string, shape=[None], name="serialized_example")
+    dataset = tf.data.Dataset.from_tensor_slices(serialized_example)
+    dataset = dataset.map(self.decode_example)
+    dataset = dataset.map(lambda ex: self.preprocess_example(ex, mode, hparams))
+    dataset = dataset.map(data_reader.cast_int64_to_int32)
+    dataset = dataset.padded_batch(1000, dataset.output_shapes)
+    dataset = dataset.map(_standardize_shapes)
+    features = tf.contrib.data.get_single_element(dataset)
+
+    if self.has_inputs:
+      features.pop("targets", None)
+
+    return tf.estimator.export.ServingInputReceiver(
+        features=features, receiver_tensors=serialized_example)
+
 
 class FeatureInfo(object):
 
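The new `serving_input_fn` reuses the same decode and preprocess path as training, so an exported model sees identically shaped features. A minimal export sketch; `my_problem`, `hparams`, and `estimator` are assumed to already exist (e.g. built with tensor2tensor's trainer utilities), and the output path is hypothetical:

```python
estimator.export_savedmodel(
    "/tmp/t2t_export",  # hypothetical export directory
    lambda: my_problem.serving_input_fn(hparams))
```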
@@ -907,3 +906,28 @@ def _summarize_features(features, num_shards=1):
       tf.summary.scalar("%s_nonpadding_tokens" % k, nonpadding_tokens)
       tf.summary.scalar("%s_nonpadding_fraction" % k,
                         tf.reduce_mean(nonpadding))
+
+
+def _standardize_shapes(features, batch_size=None):
+  """Set the right shapes for the features."""
+
+  for fname in ["inputs", "targets"]:
+    if fname not in features:
+      continue
+
+    f = features[fname]
+    while len(f.get_shape()) < 4:
+      f = tf.expand_dims(f, axis=-1)
+
+    features[fname] = f
+
+  if batch_size:
+    # Ensure batch size is set on all features
+    for _, t in six.iteritems(features):
+      shape = t.get_shape().as_list()
+      shape[0] = batch_size
+      t.set_shape(t.get_shape().merge_with(shape))
+      # Assert shapes are fully known
+      t.get_shape().assert_is_fully_defined()
+
+  return features
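A quick illustrative check of `_standardize_shapes`, not part of the commit: a rank-2 feature is padded out to the rank-4 layout the models expect, and a truthy `batch_size` additionally pins the leading dimension:

```python
import tensorflow as tf

features = {"inputs": tf.zeros([8, 20], dtype=tf.int32)}

features = _standardize_shapes(features)
print(features["inputs"].get_shape())  # (8, 20, 1, 1)

# With batch_size set (the TPU path), shapes must end up fully defined:
features = _standardize_shapes(features, batch_size=8)
features["inputs"].get_shape().assert_is_fully_defined()
```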