Skip to content
This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit 35789ec

Browse files
T2T Teamcopybara-github
authored andcommitted
Allow for the export of extra outputs from Infer().
PiperOrigin-RevId: 302541178
1 parent 1c965c4 commit 35789ec

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

tensor2tensor/utils/decoding.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@ def decode_hparams(overrides=""):
9797
# Used for MLPerf compliance logging.
9898
mlperf_decode_step=0.0,
9999
mlperf_threshold=25.0,
100-
mlperf_success=False)
100+
mlperf_success=False,
101+
# A comma-delimited list of additional infer() outputs to be exported.
102+
export_extra_infer_outputs="")
101103
hp.parse(overrides)
102104
return hp
103105

tensor2tensor/utils/t2t_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1744,6 +1744,10 @@ def estimator_spec_predict(self, features, use_tpu=False):
17441744
if "scores" in predictions:
17451745
export_out["scores"] = predictions["scores"]
17461746

1747+
if decode_hparams.get("export_extra_infer_outputs"):
1748+
for output in decode_hparams.export_extra_infer_outputs.split(","):
1749+
export_out[output] = infer_out[output]
1750+
17471751
# Necessary to rejoin examples in the correct order with the Cloud ML Engine
17481752
# batch prediction API.
17491753
if "batch_prediction_key" in predictions:

0 commit comments

Comments
 (0)