|
18 | 18 | from __future__ import absolute_import |
19 | 19 | from __future__ import division |
20 | 20 | from __future__ import print_function |
21 | | - |
22 | 21 | # Dependency imports |
23 | | - |
24 | | -from six.moves import xrange # pylint: disable=redefined-builtin |
25 | | - |
26 | 22 | from tensor2tensor.layers import common_layers |
27 | 23 | from tensor2tensor.models import transformer |
28 | 24 | from tensor2tensor.utils import expert_utils |
29 | 25 | from tensor2tensor.utils import registry |
30 | 26 | from tensor2tensor.utils import t2t_model |
31 | | - |
32 | 27 | import tensorflow as tf |
33 | 28 |
|
34 | 29 |
|
@@ -207,7 +202,7 @@ def embed(x): |
207 | 202 | shape=[hparams.v_size, hparams.hidden_size]) |
208 | 203 | h1 = tf.gather(means, x) |
209 | 204 | elif hparams.bottleneck_kind == "rounding": |
210 | | - h1 = tf.round(x) |
| 205 | + h1 = x |
211 | 206 |
|
212 | 207 | h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2") |
213 | 208 | return tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin") |
@@ -255,9 +250,19 @@ def embed(x): |
255 | 250 | x_means_hot, x_means, l = kmeans(x, means, hparams, name="vq-vae-kmeans") |
256 | 251 | h1 = tf.stop_gradient(x_means) + x - tf.stop_gradient(x) |
257 | 252 | c = tf.argmax(x_means_hot, axis=-1) |
258 | | - if hparams.bottleneck_kind == "round": |
259 | | - c = tf.round(x) |
260 | | - h1 = x + tf.stop_gradient(tf.round(x) - x) |
| 253 | + if hparams.bottleneck_kind == "rounding": |
| 254 | + h = tf.layers.dense(x, 1, name="vcc") |
| 255 | + |
| 256 | + # Make h between 0 and 1 |
| 257 | + h = tf.sigmoid(h) |
| 258 | + |
| 259 | + # Multiply by z_size to get it between [0, z_size] |
| 260 | + h *= hparams.v_size |
| 261 | + |
| 262 | + # Use the rounding bottleneck |
| 263 | + h1 = h + tf.stop_gradient(tf.round(h) - h) |
| 264 | + c = tf.squeeze(tf.round(h), axis=-1) |
| 265 | + c = tf.to_int32(c) |
261 | 266 | h2 = tf.layers.dense(tf.nn.relu(h1), filter_size, name="vch2") |
262 | 267 | res = tf.layers.dense(tf.nn.relu(h2), hparams.hidden_size, name="vcfin") |
263 | 268 | return res, c, l, embed |
|
0 commit comments