
Commit f0e638e

martinpopel authored and rsepassi committed
make save_checkpoints_secs work again (#521)
The functionality was broken during the adoption of the TPU-oriented trainer_lib.py in place of the original trainer_utils.py. Currently, the default is to save checkpoints every 2000 steps, whereas in previous T2T versions the default was every 10 minutes.
1 parent afba9dc commit f0e638e
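
As a minimal sketch of the behavior this commit restores (illustrative values, not the full T2T flag set; TensorFlow's RunConfig treats save_checkpoints_steps and save_checkpoints_secs as mutually exclusive, so one of the two must be None):

    # Step-based saving is the default; requesting a time-based interval
    # disables it by setting the step count to None.
    iterations_per_loop = 1000    # stand-in for FLAGS.iterations_per_loop
    local_eval_frequency = 2000   # stand-in for FLAGS.local_eval_frequency
    save_checkpoints_secs = 0     # stand-in for FLAGS.save_checkpoints_secs

    save_ckpt_steps = max(iterations_per_loop, local_eval_frequency)
    if save_checkpoints_secs:     # 0 is falsy, so 0 keeps step-based saving
        save_ckpt_steps = None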

File tree

3 files changed: +9 −4 lines

tensor2tensor/bin/t2t_trainer.py (5 additions, 2 deletions)

@@ -108,14 +108,17 @@ def create_experiment_fn():
 
 
 def create_run_config(hp):
+  save_ckpt_steps = max(FLAGS.iterations_per_loop, FLAGS.local_eval_frequency)
+  if FLAGS.save_checkpoints_secs:
+    save_ckpt_steps = None
   return trainer_lib.create_run_config(
       model_dir=os.path.expanduser(FLAGS.output_dir),
       master=FLAGS.master,
       iterations_per_loop=FLAGS.iterations_per_loop,
       num_shards=FLAGS.tpu_num_shards,
       log_device_placement=FLAGS.log_device_placement,
-      save_checkpoints_steps=max(FLAGS.iterations_per_loop,
-                                 FLAGS.local_eval_frequency),
+      save_checkpoints_steps=save_ckpt_steps,
+      save_checkpoints_secs=FLAGS.save_checkpoints_secs,
       keep_checkpoint_max=FLAGS.keep_checkpoint_max,
       keep_checkpoint_every_n_hours=FLAGS.keep_checkpoint_every_n_hours,
       num_gpus=FLAGS.worker_gpu,

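With the patch above, a time-based interval can again be requested on the command line. A hypothetical invocation (the problem and model names are placeholders, and exact flag names vary across T2T versions):

    t2t-trainer \
      --problems=translate_ende_wmt32k \
      --model=transformer \
      --hparams_set=transformer_base \
      --output_dir=~/t2t_train \
      --save_checkpoints_secs=600

Here --save_checkpoints_secs=600 restores the old every-10-minutes schedule; leaving it at 0 keeps the step-based default.
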
tensor2tensor/utils/flags.py (2 additions, 2 deletions)

@@ -81,8 +81,8 @@
                      "The default value 10,000 hours effectively disables it.")
 flags.DEFINE_integer("save_checkpoints_secs", 0,
                      "Save checkpoints every this many seconds. "
-                     "Default=0 means let tensorflow.contrib.learn.python.learn"
-                     " decide, which is currently set to 600 = 10 minutes.")
+                     "Default=0 means save checkpoints each x steps where x "
+                     "depends on iterations_per_loop and local_eval_frequency.")
 flags.DEFINE_bool("log_device_placement", False,
                   "Whether to log device placement.")

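The 0 default doubles as a "not set" sentinel: because 0 is falsy in Python, the if FLAGS.save_checkpoints_secs: check in t2t_trainer.py leaves step-based saving in place. A self-contained sketch of the pattern using abseil-style flags (hypothetical script, not T2T code):

    from absl import app, flags

    FLAGS = flags.FLAGS
    flags.DEFINE_integer("save_checkpoints_secs", 0,
                         "Save checkpoints every this many seconds; 0 disables.")

    def main(_):
      # 0 is falsy, so the default keeps time-based saving off.
      if FLAGS.save_checkpoints_secs:
        print("time-based saving every %d secs" % FLAGS.save_checkpoints_secs)
      else:
        print("step-based saving (default)")

    if __name__ == "__main__":
      app.run(main)
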
tensor2tensor/utils/trainer_lib.py (2 additions, 0 deletions)

@@ -87,6 +87,7 @@ def create_run_config(master="",
                       num_shards=8,
                       log_device_placement=False,
                       save_checkpoints_steps=1000,
+                      save_checkpoints_secs=0,
                       keep_checkpoint_max=20,
                       keep_checkpoint_every_n_hours=10000,
                       num_gpus=1,
@@ -121,6 +122,7 @@ def create_run_config(master="",
       "session_config": session_config,
       "save_summary_steps": 100,
       "save_checkpoints_steps": save_checkpoints_steps,
+      "save_checkpoints_secs": save_checkpoints_secs,
       "keep_checkpoint_max": keep_checkpoint_max,
       "keep_checkpoint_every_n_hours": keep_checkpoint_every_n_hours,
       "tf_random_seed": random_seed,

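Downstream, these dict entries presumably become keyword arguments to a RunConfig. A minimal sketch of that hand-off (hypothetical glue code; tf.estimator.RunConfig stands in for whichever RunConfig class T2T used at the time, and it rejects having both save_* arguments set at once, which is why t2t_trainer.py nulls out the step count when a time interval is requested):

    import tensorflow as tf  # TF 1.x API

    run_config_args = {
        "model_dir": "/tmp/t2t_train",      # illustrative path
        "save_summary_steps": 100,
        "save_checkpoints_steps": None,     # disabled: a time interval is set
        "save_checkpoints_secs": 600,       # save every 10 minutes
        "keep_checkpoint_max": 20,
        "keep_checkpoint_every_n_hours": 10000,
    }
    run_config = tf.estimator.RunConfig(**run_config_args)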