This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 1c98b8e

Merge pull request #566 from deepsense-ai/rl_init
Initial commit of reinforcement learning module.
2 parents 103d057 + 3707499 commit 1c98b8e

File tree

13 files changed: +930 −0 lines

setup.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@
         'future',
         'gevent',
         'gunicorn',
+        'gym<=0.9.5',  # gym in version 0.9.6 has some temporary issues.
+        'munch',
         'numpy',
         'requests',
         'scipy',

tensor2tensor/bin/t2t-rl-trainer

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
#!/usr/bin/env python
"""t2t-rl-trainer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensor2tensor.bin import t2t_rl_trainer

import tensorflow as tf


def main(argv):
  t2t_rl_trainer.main()  # Note: the module-level main takes no arguments.


if __name__ == "__main__":
  tf.app.run()
tensor2tensor/bin/t2t_rl_trainer.py

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Training of RL agent with PPO algorithm."""

from __future__ import absolute_import

import functools
from munch import Munch
import tensorflow as tf

from tensor2tensor.rl.collect import define_collect
from tensor2tensor.rl.envs.utils import define_batch_env
from tensor2tensor.rl.ppo import define_ppo_epoch


def define_train(policy_lambda, env_lambda, config):
  env = env_lambda()
  action_space = env.action_space
  observation_space = env.observation_space

  batch_env = define_batch_env(env_lambda, config["num_agents"])

  policy_factory = tf.make_template(
      'network',
      functools.partial(policy_lambda, observation_space,
                        action_space, config))

  (collect_op, memory) = define_collect(policy_factory, batch_env, config)

  with tf.control_dependencies([collect_op]):
    ppo_op = define_ppo_epoch(memory, policy_factory, config)

  return ppo_op


def main():
  train(example_params())


def train(params):
  policy_lambda, env_lambda, config = params
  ppo_op = define_train(policy_lambda, env_lambda, config)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(config.epochs_num):
      sess.run(ppo_op)


def example_params():
  from tensor2tensor.rl import networks
  config = {}
  config['init_mean_factor'] = 0.1
  config['init_logstd'] = 0.1
  config['policy_layers'] = 100, 100
  config['value_layers'] = 100, 100
  config['num_agents'] = 30
  config['clipping_coef'] = 0.2
  config['gae_gamma'] = 0.99
  config['gae_lambda'] = 0.95
  config['entropy_loss_coef'] = 0.01
  config['value_loss_coef'] = 1
  config['optimizer'] = tf.train.AdamOptimizer
  config['learning_rate'] = 1e-4
  config['optimization_epochs'] = 15
  config['epoch_length'] = 200
  config['epochs_num'] = 2000

  config = Munch(config)
  return networks.feed_forward_gaussian_fun, pendulum_lambda, config


def pendulum_lambda():
  import gym
  return gym.make("Pendulum-v0")


if __name__ == '__main__':
  main()
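
Note: given the `t2t-rl-trainer` wrapper above, the command-line entry point boils down to the following programmatic call. This is a minimal sketch, assuming the module path `tensor2tensor.bin.t2t_rl_trainer` implied by the wrapper's import.

```python
# Programmatic equivalent of running the t2t-rl-trainer binary:
# build the example Pendulum-v0 setup and train on it.
from tensor2tensor.bin import t2t_rl_trainer

t2t_rl_trainer.train(t2t_rl_trainer.example_params())
```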

tensor2tensor/rl/README.md

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
# Tensor2Tensor Reinforcement Learning starter.

The rl package is intended to make it possible to run reinforcement
learning algorithms within TensorFlow's computation graph.

Currently the only supported algorithm is Proximal Policy Optimization (PPO).

## Sample usage - training in the Pendulum-v0 environment.

```t2t-rl-trainer```
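
Note: beyond the bare command, a run can be customized by editing the config returned by `example_params` before calling `train`. The sketch below is illustrative only; the overridden values are assumptions, not defaults suggested by the README.

```python
from tensor2tensor.bin import t2t_rl_trainer

# Reuse the example Pendulum-v0 setup but override a couple of config
# fields (Munch allows attribute-style assignment). Values are illustrative.
policy_lambda, env_lambda, config = t2t_rl_trainer.example_params()
config.num_agents = 10    # fewer parallel environments
config.epochs_num = 100   # shorter run than the 2000-epoch default

t2t_rl_trainer.train((policy_lambda, env_lambda, config))
```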

tensor2tensor/rl/__init__.py

Whitespace-only changes.

tensor2tensor/rl/collect.py

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Collect trajectories from interactions of agent with environment."""

import tensorflow as tf


def define_collect(policy_factory, batch_env, config):

  memory_shape = [config.epoch_length] + [batch_env.observ.shape.as_list()[0]]
  memories_shapes_and_types = [
      # observation
      (memory_shape + [batch_env.observ.shape.as_list()[1]], tf.float32),
      (memory_shape, tf.float32),  # reward
      (memory_shape, tf.bool),  # done
      (memory_shape + batch_env.action_shape, tf.float32),  # action
      (memory_shape, tf.float32),  # pdf
      (memory_shape, tf.float32),  # value function
  ]
  memory = [tf.Variable(tf.zeros(shape, dtype), trainable=False)
            for (shape, dtype) in memories_shapes_and_types]
  cumulative_rewards = tf.Variable(
      tf.zeros(config.num_agents, tf.float32), trainable=False)

  should_reset_var = tf.Variable(True, trainable=False)
  reset_op = tf.cond(should_reset_var,
                     lambda: batch_env.reset(tf.range(config.num_agents)),
                     lambda: 0.0)
  with tf.control_dependencies([reset_op]):
    reset_once_op = tf.assign(should_reset_var, False)

  with tf.control_dependencies([reset_once_op]):

    def step(index, scores_sum, scores_num):
      # Note - the only way to ensure making a copy of tensor is to run simple
      # operation. We are waiting for tf.copy:
      # https://github.com/tensorflow/tensorflow/issues/11186
      obs_copy = batch_env.observ + 0
      actor_critic = policy_factory(tf.expand_dims(obs_copy, 0))
      policy = actor_critic.policy
      action = policy.sample()
      postprocessed_action = actor_critic.action_postprocessing(action)
      simulate_output = batch_env.simulate(postprocessed_action[0, ...])
      pdf = policy.prob(action)[0]
      with tf.control_dependencies(simulate_output):
        reward, done = simulate_output
        done = tf.reshape(done, (config.num_agents,))
        to_save = [obs_copy, reward, done, action[0, ...], pdf,
                   actor_critic.value[0]]
        save_ops = [tf.scatter_update(memory_slot, index, value)
                    for memory_slot, value in zip(memory, to_save)]
        cumulate_rewards_op = cumulative_rewards.assign_add(reward)
        agent_indicies_to_reset = tf.where(done)[:, 0]
        with tf.control_dependencies([cumulate_rewards_op]):
          scores_sum_delta = tf.reduce_sum(
              tf.gather(cumulative_rewards, agent_indicies_to_reset))
          scores_num_delta = tf.count_nonzero(done, dtype=tf.int32)
        with tf.control_dependencies(save_ops + [scores_sum_delta,
                                                 scores_num_delta]):
          reset_env_op = batch_env.reset(agent_indicies_to_reset)
          reset_cumulative_rewards_op = tf.scatter_update(
              cumulative_rewards, agent_indicies_to_reset,
              tf.zeros(tf.shape(agent_indicies_to_reset)))
        with tf.control_dependencies([reset_env_op,
                                      reset_cumulative_rewards_op]):
          return [index + 1, scores_sum + scores_sum_delta,
                  scores_num + scores_num_delta]

    init = [tf.constant(0), tf.constant(0.0), tf.constant(0)]
    index, scores_sum, scores_num = tf.while_loop(
        lambda c, _1, _2: c < config.epoch_length,
        step,
        init,
        parallel_iterations=1,
        back_prop=False)
    mean_score = tf.cond(tf.greater(scores_num, 0),
                         lambda: scores_sum / tf.cast(scores_num, tf.float32),
                         lambda: 0.)
    printing = tf.Print(0, [mean_score, scores_sum, scores_num], "mean_score: ")
    with tf.control_dependencies([printing]):
      return tf.identity(index), memory
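
Note: the heart of `define_collect` is a `tf.while_loop` whose body writes each step's tensors into preallocated memory variables with `tf.scatter_update`. Below is a stripped-down, self-contained sketch of just that pattern with toy values (hypothetical names, not the real trajectory collection).

```python
import tensorflow as tf

# Toy version of the memory-filling pattern used by define_collect:
# the while_loop body writes one row per iteration into a Variable.
epoch_length, num_agents = 5, 3
memory = tf.Variable(tf.zeros([epoch_length, num_agents]), trainable=False)

def step(index):
  fake_reward = tf.fill([num_agents], tf.cast(index, tf.float32))
  write_op = tf.scatter_update(memory, index, fake_reward)
  with tf.control_dependencies([write_op]):
    return index + 1

final_index = tf.while_loop(lambda i: i < epoch_length, step,
                            [tf.constant(0)],
                            parallel_iterations=1, back_prop=False)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(final_index)
  print(sess.run(memory))  # row i holds the value i
```

The real loop additionally threads score accumulators through the loop variables and chains control dependencies so environments are reset only after their cumulative rewards have been gathered.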

tensor2tensor/rl/envs/__init__.py

Whitespace-only changes.

tensor2tensor/rl/envs/batch_env.py

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
# coding=utf-8
# Copyright 2018 The Tensor2Tensor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# The code was based on Danijar Hafner's code from tf.agents:
# https://github.com/tensorflow/agents/blob/master/agents/tools/batch_env.py

"""Combine multiple environments to step them in batch."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np


class BatchEnv(object):
  """Combine multiple environments to step them in batch."""

  def __init__(self, envs, blocking):
    """Combine multiple environments to step them in batch.

    To step environments in parallel, environments must support a
    `blocking=False` argument to their step and reset functions that makes them
    return callables instead to receive the result at a later time.

    Args:
      envs: List of environments.
      blocking: Step environments one after another rather than in parallel.

    Raises:
      ValueError: Environments have different observation or action spaces.
    """
    self._envs = envs
    self._blocking = blocking
    observ_space = self._envs[0].observation_space
    if not all(env.observation_space == observ_space for env in self._envs):
      raise ValueError('All environments must use the same observation space.')
    action_space = self._envs[0].action_space
    if not all(env.action_space == action_space for env in self._envs):
      raise ValueError('All environments must use the same action space.')

  def __len__(self):
    """Number of combined environments."""
    return len(self._envs)

  def __getitem__(self, index):
    """Access an underlying environment by index."""
    return self._envs[index]

  def __getattr__(self, name):
    """Forward unimplemented attributes to one of the original environments.

    Args:
      name: Attribute that was accessed.

    Returns:
      Value of the attribute taken from the first of the wrapped environments.
    """
    return getattr(self._envs[0], name)

  def step(self, actions):
    """Forward a batch of actions to the wrapped environments.

    Args:
      actions: Batched action to apply to the environment.

    Raises:
      ValueError: Invalid actions.

    Returns:
      Batch of observations, rewards, and done flags.
    """
    for index, (env, action) in enumerate(zip(self._envs, actions)):
      if not env.action_space.contains(action):
        message = 'Invalid action at index {}: {}'
        raise ValueError(message.format(index, action))
    if self._blocking:
      transitions = [
          env.step(action)
          for env, action in zip(self._envs, actions)]
    else:
      transitions = [
          env.step(action, blocking=False)
          for env, action in zip(self._envs, actions)]
      transitions = [transition() for transition in transitions]
    observs, rewards, dones, infos = zip(*transitions)
    observ = np.stack(observs).astype(np.float32)
    reward = np.stack(rewards).astype(np.float32)
    done = np.stack(dones)
    info = tuple(infos)
    return observ, reward, done, info

  def reset(self, indices=None):
    """Reset the environment and convert the resulting observation.

    Args:
      indices: The batch indices of environments to reset; defaults to all.

    Returns:
      Batch of observations.
    """
    if indices is None:
      indices = np.arange(len(self._envs))
    if self._blocking:
      observs = [self._envs[index].reset() for index in indices]
    else:
      observs = [self._envs[index].reset(blocking=False) for index in indices]
      observs = [observ() for observ in observs]
    observ = np.stack(observs)
    observ = observ.astype(np.float32)
    return observ

  def close(self):
    """Send close messages to the external process and join them."""
    for env in self._envs:
      if hasattr(env, 'close'):
        env.close()
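
Note: `BatchEnv` is not exercised directly anywhere in this excerpt, so here is a minimal, hedged sketch of stepping it with two ordinary (blocking) gym environments, matching the README's Pendulum-v0 example. The shape comment assumes Pendulum-v0's 3-dimensional observation.

```python
import gym
import numpy as np

from tensor2tensor.rl.envs.batch_env import BatchEnv

# Two ordinary (blocking) gym environments stepped as a batch.
envs = [gym.make("Pendulum-v0") for _ in range(2)]
batch_env = BatchEnv(envs, blocking=True)

observ = batch_env.reset()  # float32 array of shape (2, 3) for Pendulum-v0
actions = np.stack([env.action_space.sample() for env in envs])
observ, reward, done, _ = batch_env.step(actions)

batch_env.close()
```

With `blocking=False`, each wrapped environment would instead have to return callables from `step`/`reset`, which is what the external-process wrappers elsewhere in this module rely on for parallel stepping.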
