Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 8548dab

Browse files
author
Ryan Sepassi
committed
Support TF 1.5 launching an ML Engine job (which only supports 1.4)
PiperOrigin-RevId: 185954947
1 parent 5654701 commit 8548dab

File tree

5 files changed

+121
-6
lines changed

5 files changed

+121
-6
lines changed

docs/new_problem.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class PoetryLines(text_problems.Text2TextProblem):
5959
# generate_data will shard the data into TRAIN and EVAL for us.
6060
return False
6161

62+
@property
6263
def dataset_splits(self):
6364
"""Splits of data to produce and number of output shards for each."""
6465
# 10% evaluation data
@@ -141,6 +142,7 @@ training data will be generated into 90 files and the evaluation data into 10.
141142
# generate_data will shard the data into TRAIN and EVAL for us.
142143
return False
143144

145+
@property
144146
def dataset_splits(self):
145147
"""Splits of data to produce and number of output shards for each."""
146148
# 10% evaluation data

tensor2tensor/bin/t2t_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -322,8 +322,8 @@ def main(argv):
322322
if FLAGS.generate_data:
323323
generate_data()
324324

325-
if hasattr(FLAGS, "job_dir") and FLAGS.job_dir:
326-
FLAGS.output_dir = FLAGS.job_dir
325+
if cloud_mlengine.job_dir():
326+
FLAGS.output_dir = cloud_mlengine.job_dir()
327327

328328
if argv:
329329
set_hparams_from_args(argv[1:])

tensor2tensor/test_data/example_usr_dir/my_submodule.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,17 @@
1414
# limitations under the License.
1515

1616
"""Example registrations for T2T."""
17+
import re
18+
19+
from tensor2tensor.data_generators import problem
20+
from tensor2tensor.data_generators import text_problems
1721
from tensor2tensor.layers import common_hparams
1822
from tensor2tensor.utils import registry
1923

24+
# Use register_model for a new T2TModel
25+
# Use register_problem for a new Problem
26+
# Use register_hparams for a new hyperparameter set
27+
2028

2129
@registry.register_hparams
2230
def my_very_own_hparams():
@@ -28,5 +36,64 @@ def my_very_own_hparams():
2836
hp.add_hparam("filter_size", 2048)
2937
return hp
3038

31-
# Use register_model for a new T2TModel
32-
# Use register_problem for a new Problem
39+
40+
@registry.register_problem
class PoetryLines(text_problems.Text2TextProblem):
  """Predict next line of poetry from the last line. From Gutenberg texts."""

  @property
  def approx_vocab_size(self):
    """Approximate vocabulary size for this problem."""
    return 2**13  # ~8k

  @property
  def is_generate_per_split(self):
    """Returns False: samples are generated once and sharded afterwards."""
    # generate_data will shard the data into TRAIN and EVAL for us.
    return False

  @property
  def dataset_splits(self):
    """Splits of data to produce and number of output shards for each."""
    # 10% evaluation data
    return [{
        "split": problem.DatasetSplit.TRAIN,
        "shards": 90,
    }, {
        "split": problem.DatasetSplit.EVAL,
        "shards": 10,
    }]

  def generate_samples(self, data_dir, tmp_dir, dataset_split):
    """Yields consecutive-line pairs from a fixed list of Gutenberg books.

    Each book is downloaded, its Project Gutenberg headers are stripped, and
    a per-book number of leading lines is skipped. Every remaining line is
    lower-cased and reduced to runs of the letters a-z separated by single
    spaces. A pair is produced for each line that directly follows another
    usable line; empty or all-upper-case lines break the chain.

    Args:
      data_dir: unused.
      tmp_dir: unused.
      dataset_split: unused.

    Yields:
      Dicts with "inputs" (the previous normalized line) and "targets" (the
      current normalized line), both non-empty strings.
    """
    del data_dir
    del tmp_dir
    del dataset_split

    # Imported here rather than at module level, presumably so that merely
    # importing this module does not require the gutenberg package.
    # pylint: disable=g-import-not-at-top
    from gutenberg import acquire
    from gutenberg import cleanup
    # pylint: enable=g-import-not-at-top

    books = [
        # bookid, skip N lines
        (19221, 223),
        (15553, 522),
    ]

    # Compiled once; every run of characters outside a-z becomes one space.
    non_letters = re.compile(r"[^a-z]+")

    for (book_id, toskip) in books:
      text = cleanup.strip_headers(acquire.load_etext(book_id)).strip()
      lines = text.split("\n")[toskip:]
      prev_line = None
      for line in lines:
        # Any line that is all upper case is a title or author name
        if not line or line.upper() == line:
          prev_line = None
          continue

        # Strip again after substituting: leading/trailing punctuation turns
        # into spaces, and a line without any a-z letters would otherwise
        # become " " and slip through the emptiness check below.
        line = non_letters.sub(" ", line.strip().lower()).strip()
        if prev_line and line:
          yield {
              "inputs": prev_line,
              "targets": line,
          }
        prev_line = line

tensor2tensor/test_data/example_usr_dir/setup.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# coding=utf-8
2+
# Copyright 2018 The Tensor2Tensor Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Example setup.py for a t2t_usr_dir launching on Cloud ML Engine.
17+
18+
This is only necessary if you have additional required pip packages for the
19+
import of your usr_dir, and only if you're launching t2t-trainer on Cloud ML
20+
Engine with the --cloud_mlengine flag.
21+
22+
Note that the call to setup uses find_packages() and that the location of this
23+
file is alongside the __init__.py file that imports my_submodule.
24+
"""
25+
from setuptools import find_packages
26+
from setuptools import setup
27+
# Extra pip packages that must be installed for the usr_dir import to work.
REQUIRED_PACKAGES = [
    'gutenberg',
]

# find_packages() discovers the packages located next to this file.
setup(
    name='DummyUsrDirPackage',
    version='0.1',
    packages=find_packages(),
    install_requires=REQUIRED_PACKAGES,
)

tensor2tensor/utils/cloud_mlengine.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,24 @@
4747
"""
4848

4949

50+
def job_dir():
  """Returns the value of the --job-dir flag, or a falsy value if unset.

  The flag --job-dir is parsed differently before and after switching to
  absl, so the value is looked up under the dashed attribute name first and
  under the underscored name as a fallback.
  """
  dashed_value = getattr(FLAGS, 'job-dir', '')
  if dashed_value:
    return dashed_value
  return getattr(FLAGS, 'job_dir', '')
53+
54+
5055
def flags_as_args():
5156
"""Convert FLAGS to list of args suitable for passing on cmd line."""
52-
args_dict = dict(FLAGS.__dict__['__flags'])
57+
if hasattr(FLAGS, 'flag_values_dict'):
58+
args_dict = FLAGS.flag_values_dict()
59+
else:
60+
args_dict = dict(FLAGS.__dict__['__flags'])
5361
del args_dict['cloud_mlengine']
5462
# Configured later
5563
del args_dict['t2t_usr_dir']
64+
args_dict.pop('h', None)
65+
args_dict.pop('helpfull', None)
66+
args_dict.pop('helpshort', None)
67+
args_dict.pop('help', None)
5668
args = []
5769
for name, val in args_dict.items():
5870
if val is None:
@@ -223,7 +235,7 @@ def configure_usr_dir(job_spec, usr_tar):
223235
def launch():
224236
"""Launch t2t_trainer on Cloud ML Engine."""
225237
assert not FLAGS.cloud_tpu
226-
assert not FLAGS.job_dir
238+
assert not job_dir()
227239
assert FLAGS.output_dir.startswith('gs://')
228240
assert FLAGS.data_dir.startswith('gs://')
229241
assert FLAGS.worker_replicas <= 1

0 commit comments

Comments
 (0)