Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 0d250b3

Browse files
noersepassi
authored andcommitted
Support --t2t_usr_dir also in t2t-datagen (#160)
* Refactor user directory loading functionality and use it also from t2t-datagen * Move flag declaration to the binary files
1 parent c91989c commit 0d250b3

File tree

3 files changed

+46
-18
lines changed

3 files changed

+46
-18
lines changed

tensor2tensor/bin/t2t-datagen

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ from tensor2tensor.data_generators import wiki
4848
from tensor2tensor.data_generators import wmt
4949
from tensor2tensor.data_generators import wsj_parsing
5050
from tensor2tensor.utils import registry
51+
from tensor2tensor.utils import usr_dir
5152

5253
import tensorflow as tf
5354

@@ -64,6 +65,13 @@ flags.DEFINE_integer("max_cases", 0,
6465
"Maximum number of cases to generate (unbounded if 0).")
6566
flags.DEFINE_integer("random_seed", 429459, "Random seed to use.")
6667

68+
flags.DEFINE_string("t2t_usr_dir", "",
69+
"Path to a Python module that will be imported. The "
70+
"__init__.py file should include the necessary imports. "
71+
"The imported files should contain registrations, "
72+
"e.g. @registry.register_model calls, that will then be "
73+
"available to the t2t-datagen.")
74+
6775
# Mapping from problems that we can generate data for to their generators.
6876
# pylint: disable=g-long-lambda
6977
_SUPPORTED_PROBLEM_GENERATORS = {
@@ -273,6 +281,7 @@ def set_random_seed():
273281

274282
def main(_):
275283
tf.logging.set_verbosity(tf.logging.INFO)
284+
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
276285

277286
# Calculate the list of problems to generate.
278287
problems = sorted(

tensor2tensor/bin/t2t-trainer

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ import sys
3636
# Dependency imports
3737

3838
from tensor2tensor.utils import trainer_utils as utils
39-
39+
from tensor2tensor.utils import usr_dir
4040
import tensorflow as tf
4141

4242
flags = tf.flags
@@ -49,25 +49,9 @@ flags.DEFINE_string("t2t_usr_dir", "",
4949
"e.g. @registry.register_model calls, that will then be "
5050
"available to the t2t-trainer.")
5151

52-
53-
def import_usr_dir():
54-
"""Import module at FLAGS.t2t_usr_dir, if provided."""
55-
if not FLAGS.t2t_usr_dir:
56-
return
57-
dir_path = os.path.expanduser(FLAGS.t2t_usr_dir)
58-
if dir_path[-1] == "/":
59-
dir_path = dir_path[:-1]
60-
containing_dir, module_name = os.path.split(dir_path)
61-
tf.logging.info("Importing user module %s from path %s", module_name,
62-
containing_dir)
63-
sys.path.insert(0, containing_dir)
64-
importlib.import_module(module_name)
65-
sys.path.pop(0)
66-
67-
6852
def main(_):
6953
tf.logging.set_verbosity(tf.logging.INFO)
70-
import_usr_dir()
54+
usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
7155
utils.log_registry()
7256
utils.validate_flags()
7357
utils.run(

tensor2tensor/utils/usr_dir.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2017 The Tensor2Tensor Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Utility to load code from an external directory supplied by user."""
16+
17+
import os
18+
import sys
19+
import importlib
20+
import tensorflow as tf
21+
22+
23+
def import_usr_dir(usr_dir):
24+
"""Import user module, if provided."""
25+
if not usr_dir:
26+
return
27+
dir_path = os.path.expanduser(usr_dir)
28+
if dir_path[-1] == "/":
29+
dir_path = dir_path[:-1]
30+
containing_dir, module_name = os.path.split(dir_path)
31+
tf.logging.info("Importing user module %s from path %s", module_name,
32+
containing_dir)
33+
sys.path.insert(0, containing_dir)
34+
importlib.import_module(module_name)
35+
sys.path.pop(0)

0 commit comments

Comments
 (0)