Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit ee922bd

Browse files
author
Ryan Sepassi
committed
Remove added var scopes in @recompute_grad and @fn_with_custom_grad
PiperOrigin-RevId: 172016510
1 parent dc190ec commit ee922bd

File tree

2 files changed

+9
-10
lines changed

2 files changed

+9
-10
lines changed

tensor2tensor/layers/common_layers.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1943,13 +1943,13 @@ def _fn_with_custom_grad(fn, inputs, grad_fn, use_global_vars=False):
19431943
Returns:
19441944
fn(*inputs)
19451945
"""
1946-
with tf.variable_scope(None, default_name="fn_with_custom_grad") as vs:
1947-
inputs = list(inputs)
1948-
outputs = fn(*inputs)
1949-
if use_global_vars:
1950-
train_vars = list(vs.global_variables())
1951-
else:
1952-
train_vars = list(vs.trainable_variables())
1946+
vs = tf.get_variable_scope()
1947+
get_vars_fn = (vs.global_variables if use_global_vars else
1948+
vs.trainable_variables)
1949+
len_before_vars = len(get_vars_fn())
1950+
inputs = list(inputs)
1951+
outputs = fn(*inputs)
1952+
train_vars = get_vars_fn()[len_before_vars:]
19531953

19541954
if grad_fn is None:
19551955
return outputs

tensor2tensor/layers/rev_block.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -365,8 +365,7 @@ def grad_fn(inputs, variables, outputs, output_grads):
365365

366366
@common_layers.fn_with_custom_grad(grad_fn)
367367
def fn_with_recompute(*args):
368-
with tf.variable_scope(None, default_name="recompute") as vs:
369-
cached_vs.append(vs)
370-
return fn(*args)
368+
cached_vs.append(tf.get_variable_scope())
369+
return fn(*args)
371370

372371
return fn_with_recompute(*args)

0 commit comments

Comments
 (0)