@@ -75,15 +75,17 @@ def create_border(video, color="blue", border_percent=2):
7575
7676
7777def convert_videos_to_summaries (input_videos , output_videos , target_videos ,
78- tag , decode_hparams ):
78+ tag , decode_hparams ,
79+ display_ground_truth = False ):
7980 """Converts input, output and target videos into video summaries.
8081
8182 Args:
8283 input_videos: 5-D NumPy array, (NTHWC) conditioning frames.
83- output_videos: 5-D NumPy array, (NTHWC) ground truth .
84+ output_videos: 5-D NumPy array, (NTHWC) model predictions .
8485 target_videos: 5-D NumPy array, (NTHWC) target frames.
8586 tag: tf summary tag.
8687 decode_hparams: tf.contrib.training.HParams.
88+ display_ground_truth: Whether or not to display ground truth videos.
8789 Returns:
8890 summaries: a list of tf frame-by-frame and video summaries.
8991 """
@@ -98,18 +100,20 @@ def convert_videos_to_summaries(input_videos, output_videos, target_videos,
98100 output_videos = create_border (
99101 output_videos , color = "red" , border_percent = border_percent )
100102
101- # Video gif.
102103 all_input = np .concatenate ((input_videos , target_videos ), axis = 1 )
103104 all_output = np .concatenate ((input_videos , output_videos ), axis = 1 )
104- input_summ_vals , _ = common_video .py_gif_summary (
105- "%s/input" % tag , all_input , max_outputs = max_outputs , fps = fps ,
106- return_summary_value = True )
107105 output_summ_vals , _ = common_video .py_gif_summary (
108106 "%s/output" % tag , all_output , max_outputs = max_outputs , fps = fps ,
109107 return_summary_value = True )
110- all_summaries .extend (input_summ_vals )
111108 all_summaries .extend (output_summ_vals )
112109
110+ # Optionally display ground truth.
111+ if display_ground_truth :
112+ input_summ_vals , _ = common_video .py_gif_summary (
113+ "%s/input" % tag , all_input , max_outputs = max_outputs , fps = fps ,
114+ return_summary_value = True )
115+ all_summaries .extend (input_summ_vals )
116+
113117 # Frame-by-frame summaries
114118 iterable = zip (all_input [:max_outputs ], all_output [:max_outputs ])
115119 for ind , (input_video , output_video ) in enumerate (iterable ):
@@ -164,7 +168,8 @@ def display_video_hooks(hook_args):
164168 input_videos = np .asarray (input_videos , dtype = np .uint8 )
165169 summaries = convert_videos_to_summaries (
166170 input_videos , output_videos , target_videos ,
167- tag = "decode_%d" % decode_ind , decode_hparams = hook_args .decode_hparams )
171+ tag = "decode_%d" % decode_ind , decode_hparams = hook_args .decode_hparams ,
172+ display_ground_truth = decode_ind == 0 )
168173 all_summaries .extend (summaries )
169174 return all_summaries
170175
0 commit comments