diff --git a/tensorflow/python/training/adagrad.py b/tensorflow/python/training/adagrad.py
index 46eec12748a..80a2e91f24a 100644
--- a/tensorflow/python/training/adagrad.py
+++ b/tensorflow/python/training/adagrad.py
@@ -97,6 +97,9 @@ def _prepare(self):
     learning_rate = self._call_if_callable(self._learning_rate)
     self._learning_rate_tensor = ops.convert_to_tensor(
         learning_rate, name="learning_rate")
+    global_step_var = training_util.get_or_create_global_step()
+    with ops.colocate_with(self._learning_rate_tensor):
+      self._global_step_on_worker = array_ops.identity(global_step_var) + 1
 
   def _apply_dense(self, grad, var):
     acc = self.get_slot(var, "accumulator")
@@ -139,14 +142,13 @@ def _hash_table_apply_sparse(self, grad, var, indices):
   def _resource_apply_sparse(self, grad, var, indices):
     acc = self.get_slot(var, "accumulator")
     if isinstance(var, kv_variable_ops.EmbeddingVariable):
-      global_step = training_util.get_or_create_global_step()
       return training_ops.kv_resource_sparse_apply_adagrad(
           var.handle, acc.handle,
           math_ops.cast(self._learning_rate_tensor, grad.dtype),
           grad, indices,
-          global_step,
+          self._global_step_on_worker,
           use_locking=self._use_locking)
     else:
       return training_ops.resource_sparse_apply_adagrad(
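
For context, a minimal self-contained sketch of the pattern the `_prepare` change introduces: the global step is read once through an identity op colocated with the learning-rate tensor (i.e., placed on the worker in a parameter-server setup), so each sparse Adagrad update reuses that worker-local copy instead of fetching the global-step variable directly. This sketch uses the public `tf.compat.v1` equivalents of the internal modules in the diff; the standalone setup, names, and values are illustrative, and the `+ 1` presumably accounts for the step increment that is applied later in the same training op.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Stand-ins for the optimizer state set up in _prepare().
learning_rate_tensor = tf.convert_to_tensor(0.1, name="learning_rate")
global_step_var = tf.train.get_or_create_global_step()

# Colocate the identity read with the learning-rate tensor so the cached
# copy of the step lives next to it (on the worker in a PS deployment);
# the "+ 1" mirrors the diff's adjustment for the pending increment.
with tf.colocate_with(learning_rate_tensor):
    global_step_on_worker = tf.identity(global_step_var) + 1

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Prints 1 while the global step variable itself is still 0.
    print(sess.run(global_step_on_worker))
```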