Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 60dd5e0

Browse files
author
Ryan Sepassi
committed
Add tests for genetics problems
PiperOrigin-RevId: 162569505
1 parent 84445cc commit 60dd5e0

File tree

1 file changed

+65
-0
lines changed

1 file changed

+65
-0
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2017 The Tensor2Tensor Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for Genetics problems."""
16+
from __future__ import absolute_import
17+
from __future__ import division
18+
from __future__ import print_function
19+
20+
# Dependency imports
21+
22+
import numpy as np
23+
24+
from tensor2tensor.data_generators import genetics
25+
26+
import tensorflow as tf
27+
28+
29+
class GeneticsTest(tf.test.TestCase):
30+
31+
def _oneHotBases(self, bases):
32+
one_hots = []
33+
for base_id in bases:
34+
one_hot = [False] * 4
35+
if base_id < 4:
36+
one_hot[base_id] = True
37+
one_hots.append(one_hot)
38+
return np.array(one_hots)
39+
40+
def testRecordToExample(self):
41+
inputs = self._oneHotBases([0, 1, 3, 4, 1, 0])
42+
mask = np.array([True, False, True])
43+
outputs = np.array([[1.0, 2.0, 3.0], [5.0, 1.0, 0.2], [5.1, 2.3, 2.3]])
44+
ex_dict = genetics.to_example_dict(inputs, mask, outputs)
45+
46+
self.assertAllEqual([2, 3, 5, 6, 3, 2, 1], ex_dict["inputs"])
47+
self.assertAllEqual([1.0, 0.0, 1.0], ex_dict["targets_mask"])
48+
self.assertAllEqual([1.0, 2.0, 3.0, 5.0, 1.0, 0.2, 5.1, 2.3, 2.3],
49+
ex_dict["targets"])
50+
self.assertAllEqual([3, 3], ex_dict["targets_shape"])
51+
52+
def testGenerateShardArgs(self):
53+
num_examples = 37
54+
num_shards = 4
55+
outfiles = [str(i) for i in range(num_shards)]
56+
shard_args = genetics.generate_shard_args(outfiles, num_examples)
57+
58+
starts, ends, fnames = zip(*shard_args)
59+
self.assertAllEqual([0, 9, 18, 27], starts)
60+
self.assertAllEqual([9, 18, 27, 37], ends)
61+
self.assertAllEqual(fnames, outfiles)
62+
63+
64+
if __name__ == "__main__":
65+
tf.test.main()

0 commit comments

Comments
 (0)