 import tensorflow as tf


-
-
 flags = tf.flags
 FLAGS = flags.FLAGS

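Note: `FLAGS.model_path` is consumed further down when the policy checkpoint is restored; the flag itself is defined elsewhere in this file and is not part of this diff. A minimal sketch of what such a declaration could look like (the default value and help text are assumptions):

```python
# Hypothetical declaration; the real one lives outside the hunks shown here.
flags.DEFINE_string("model_path", "",
                    "Checkpoint of the trained policy used to generate trajectories.")
```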
@@ -50,6 +48,17 @@ def __init__(self, *args, **kwargs):
     super(GymDiscreteProblem, self).__init__(*args, **kwargs)
     self._env = None

+  def example_reading_spec(self, label_repr=None):
+
+    data_fields = {
+        "inputs": tf.FixedLenFeature([210, 160, 3], tf.int64),
+        "inputs_prev": tf.FixedLenFeature([210, 160, 3], tf.int64),
+        "targets": tf.FixedLenFeature([210, 160, 3], tf.int64),
+        "action": tf.FixedLenFeature([1], tf.int64)
+    }
+
+    return data_fields, None
+
   @property
   def env_name(self):
     # This is the name of the Gym environment for this problem.
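Note: the reading spec added above tells the data pipeline how to decode each serialized `tf.Example`: three full Atari frames (210x160x3) plus a one-element action. A minimal sketch of how such a spec is typically consumed with the TF 1.x API, assuming a serialized example string `serialized_example` (not part of the diff):

```python
import tensorflow as tf

# Parse one serialized tf.Example with the same feature spec that
# example_reading_spec() returns; each frame comes back as an int64 tensor
# of shape [210, 160, 3] and "action" as a length-1 int64 vector.
data_fields = {
    "inputs": tf.FixedLenFeature([210, 160, 3], tf.int64),
    "inputs_prev": tf.FixedLenFeature([210, 160, 3], tf.int64),
    "targets": tf.FixedLenFeature([210, 160, 3], tf.int64),
    "action": tf.FixedLenFeature([1], tf.int64),
}
parsed = tf.parse_single_example(serialized_example, data_fields)
frame = tf.cast(parsed["targets"], tf.uint8)  # back to image value range
```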
@@ -133,7 +142,7 @@ class GymPongRandom5k(GymDiscreteProblem):

   @property
   def env_name(self):
-    return "Pong-v0"
+    return "PongNoFrameskip-v4"

   @property
   def num_actions(self):
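Note: the switch from "Pong-v0" to "PongNoFrameskip-v4" matters because the `*-v0` Atari environments apply their own randomized frame skipping, while `NoFrameskip-v4` returns every raw emulator frame and leaves skipping to the wrapper (`frame_skip=4` in the generator below). A minimal sketch of the resulting behavior, assuming gym with the Atari environments installed:

```python
import gym

# NoFrameskip-v4: one emulator frame per step(); any skipping is done by wrappers.
env = gym.make("PongNoFrameskip-v4")
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
print(obs.shape)  # (210, 160, 3), matching the FixedLenFeature shapes above
```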
@@ -148,21 +157,30 @@ def num_steps(self):
     return 5000


+
 @registry.register_problem
 class GymPongTrajectoriesFromPolicy(GymDiscreteProblem):
   """Pong game, loaded actions."""

-  def __init__(self, event_dir, *args, **kwargs):
+  def __init__(self, *args, **kwargs):
     super(GymPongTrajectoriesFromPolicy, self).__init__(*args, **kwargs)
     self._env = None
-    self._event_dir = event_dir
+    self._last_policy_op = None
+    self._max_frame_pl = None
+    self._last_action = self.env.action_space.sample()
+    self._skip = 4
+    self._skip_step = 0
+    self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape,
+                                dtype=np.uint8)
+
+  def generator(self, data_dir, tmp_dir):
     env_spec = lambda: atari_wrappers.wrap_atari(  # pylint: disable=g-long-lambda
         gym.make("PongNoFrameskip-v4"),
         warp=False,
         frame_skip=4,
         frame_stack=False)
     hparams = rl.atari_base()
-    with tf.variable_scope("train"):
+    with tf.variable_scope("train", reuse=tf.AUTO_REUSE):
       policy_lambda = hparams.network
       policy_factory = tf.make_template(
           "network",
@@ -173,14 +191,13 @@ def __init__(self, event_dir, *args, **kwargs):
           self._max_frame_pl, 0), 0))
       policy = actor_critic.policy
       self._last_policy_op = policy.mode()
-      self._last_action = self.env.action_space.sample()
-      self._skip = 4
-      self._skip_step = 0
-      self._obs_buffer = np.zeros((2,) + self.env.observation_space.shape,
-                                  dtype=np.uint8)
-      self._sess = tf.Session()
-      model_saver = tf.train.Saver(tf.global_variables(".*network_parameters.*"))
-      model_saver.restore(self._sess, FLAGS.model_path)
+    with tf.Session() as sess:
+      model_saver = tf.train.Saver(
+          tf.global_variables(".*network_parameters.*"))
+      model_saver.restore(sess, FLAGS.model_path)
+      for item in super(GymPongTrajectoriesFromPolicy,
+                        self).generator(data_dir, tmp_dir):
+        yield item

   # TODO(blazej0): For training of atari agents wrappers are usually used.
   # Below we have a hacky solution which is a workaround to be used together
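Note: opening the session with `with tf.Session() as sess:` installs it as the default session for the enclosing block, which is what lets `get_action` below reach it through `tf.get_default_session()` while the parent class's generator is being iterated. A minimal sketch of that TF 1.x pattern in isolation (the placeholder and computation are illustrative, not from the diff):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[])
y = x * 2.0

with tf.Session() as sess:
  # Inside this block, helper code can pick up the session implicitly.
  def helper(value):
    return tf.get_default_session().run(y, feed_dict={x: value})

  print(helper(3.0))  # 6.0
```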
@@ -191,7 +208,7 @@ def get_action(self, observation=None):
     self._skip_step = (self._skip_step + 1) % self._skip
     if self._skip_step == 0:
       max_frame = self._obs_buffer.max(axis=0)
-      self._last_action = int(self._sess.run(
+      self._last_action = int(tf.get_default_session().run(
           self._last_policy_op,
           feed_dict={self._max_frame_pl: max_frame})[0, 0])
     return self._last_action
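Note: `get_action` queries the policy only on every `self._skip`-th call and, when it does, feeds it the element-wise maximum of the two most recent raw frames; max-pooling consecutive frames is the usual fix for Atari sprite flicker. A small numpy sketch of that buffer idea (the `push_frame` helper is hypothetical, not the code path that actually fills `_obs_buffer`):

```python
import numpy as np

# Two-slot buffer of raw frames (shape matches Atari observations).
obs_buffer = np.zeros((2, 210, 160, 3), dtype=np.uint8)

def push_frame(frame):
  """Keep only the two most recent frames."""
  obs_buffer[0] = obs_buffer[1]
  obs_buffer[1] = frame

# Element-wise max over the two frames removes single-frame sprite flicker.
max_frame = obs_buffer.max(axis=0)
```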