2525
2626import numpy as np
2727
28- from tensor2tensor .trax import inputs
28+ from tensor2tensor .trax import inputs as inputs_lib
2929from tensor2tensor .trax import models
3030from tensor2tensor .trax import trax
3131
@@ -43,7 +43,7 @@ def input_stream():
4343 yield (np .random .rand (* ([batch_size ] + list (input_shape ))),
4444 np .random .randint (num_classes , size = batch_size ))
4545
46- return inputs .Inputs (
46+ return inputs_lib .Inputs (
4747 train_stream = input_stream ,
4848 eval_stream = input_stream ,
4949 input_shape = input_shape )
@@ -57,32 +57,37 @@ def tmp_dir(self):
5757 yield tmp
5858 gfile .rmtree (tmp )
5959
60- @property
61- def train_args (self ):
62- num_classes = 4
63- return dict (
64- model = functools .partial (models .MLP ,
60+ def test_train_eval_predict (self ):
61+ with self .tmp_dir () as output_dir :
62+ # Prepare model and inputs
63+ num_classes = 4
64+ train_steps = 2
65+ eval_steps = 2
66+ model = functools .partial (models .MLP ,
6567 hidden_size = 16 ,
66- num_output_classes = num_classes ),
67- inputs = lambda : test_inputs (num_classes ),
68- train_steps = 3 ,
69- eval_steps = 2 )
68+ num_output_classes = num_classes )
69+ inputs = lambda : test_inputs (num_classes )
7070
71- def _test_train (self , train_args ):
72- with self .tmp_dir () as output_dir :
73- state = trax .train (output_dir , ** train_args )
71+ # Train and evaluate
72+ state = trax .train (output_dir ,
73+ model = model ,
74+ inputs = inputs ,
75+ train_steps = train_steps ,
76+ eval_steps = eval_steps )
7477
7578 # Assert total train steps
76- self .assertEqual (train_args [ " train_steps" ] , state .step )
79+ self .assertEqual (train_steps , state .step )
7780
78- # Assert 2 epochs ran
81+ # Assert 2 evaluations ran
7982 train_acc = state .history .get ("train" , "metrics/accuracy" )
8083 eval_acc = state .history .get ("eval" , "metrics/accuracy" )
8184 self .assertEqual (len (train_acc ), len (eval_acc ))
8285 self .assertEqual (2 , len (eval_acc ))
8386
84- def test_train (self ):
85- self ._test_train (self .train_args )
87+ # Predict with final params
88+ _ , predict_fun = model ()
89+ inputs = inputs ().train_stream ()
90+ predict_fun (state .params , next (inputs )[0 ])
8691
8792
8893if __name__ == "__main__" :
0 commit comments