Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit b448041

Browse files
Dustin TranCopybara-Service
authored andcommitted
Don't export functions when saving hparams to json.
PiperOrigin-RevId: 237176017
1 parent 1b08988 commit b448041

File tree

3 files changed

+23
-6
lines changed

3 files changed

+23
-6
lines changed

tensor2tensor/utils/hparam.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -542,15 +542,15 @@ def to_json(self, indent=None, separators=None, sort_keys=False):
542542
A JSON string.
543543
"""
544544
def remove_callables(x):
545-
if callable(x):
546-
return x.__name__
545+
"""Omit callable elements from input with arbitrary nesting."""
547546
if isinstance(x, dict):
548-
return {k: remove_callables(v) for k, v in six.iteritems(x)}
549-
if isinstance(x, list):
550-
return [remove_callables(i) for i in x]
547+
return {k: remove_callables(v) for k, v in six.iteritems(x)
548+
if not callable(v)}
549+
elif isinstance(x, list):
550+
return [remove_callables(i) for i in x if not callable(i)]
551551
return x
552552
return json.dumps(
553-
{k: remove_callables(v) for k, v in six.iteritems(self.values())},
553+
remove_callables(self.values()),
554554
indent=indent,
555555
separators=separators,
556556
sort_keys=sort_keys)

tensor2tensor/utils/hparam_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,15 @@ def testSetFromMap(self):
136136
self.assertDictEqual({'d': [0.1, 0.2, 0.3], 'x': 1, 'b': 2.0},
137137
hparams.values())
138138

139+
def testFunction(self):
140+
def f(x):
141+
return x
142+
hparams = hparam.HParams(function=f)
143+
self.assertEqual(hparams.function, f)
144+
145+
json_str = hparams.to_json()
146+
self.assertEqual(json_str, '{}')
147+
139148
def testBoolParsing(self):
140149
for value in 'true', 'false', 'True', 'False', '1', '0':
141150
for initial in False, True:

tensor2tensor/utils/hparams_lib.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,14 @@ def create_hparams_from_json(json_path, hparams=None):
6464
tf.logging.info("Loading hparams from existing json %s" % json_path)
6565
with tf.gfile.Open(json_path, "r") as f:
6666
hparams_values = json.load(f)
67+
# Prevent certain keys from overwriting the passed-in hparams.
68+
# TODO(trandustin): Remove this hack after registries are available to avoid
69+
# saving them as functions.
70+
hparams_values.pop("bottom", None)
71+
hparams_values.pop("loss", None)
72+
hparams_values.pop("name", None)
73+
hparams_values.pop("top", None)
74+
hparams_values.pop("weights_fn", None)
6775
new_hparams = HParams(**hparams_values)
6876
# Some keys are in new_hparams but not hparams, so we need to be more
6977
# careful than simply using parse_json() from HParams

0 commit comments

Comments
 (0)