From 1c5120141b2043c4e0721774c183cb01d23b0682 Mon Sep 17 00:00:00 2001 From: David Soergel <soergel@google.com> Date: Fri, 13 Jan 2017 12:14:41 -0800 Subject: [PATCH] Fix SavedModel export when predictions is a single tensor and output_alternatives not given Change: 144470271 --- .../python/learn/estimators/prediction_key.py | 1 + .../learn/utils/saved_model_export_utils.py | 2 ++ .../learn/utils/saved_model_export_utils_test.py | 15 +++++++++++++++ 3 files changed, 18 insertions(+) diff --git a/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py b/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py index 8a6d0ef0183..7dc26781f94 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py +++ b/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py @@ -25,3 +25,4 @@ class PredictionKey(object): LOGISTIC = "logistic" SCORES = "scores" TOP_K = "top_k" + GENERIC = "output" diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index c386b2adf90..c3fdd3086c4 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -167,6 +167,8 @@ def get_output_alternatives( # interpret the model as single-headed of unknown type. default_problem_type = constants.ProblemType.UNSPECIFIED default_outputs = model_fn_ops.predictions + if not isinstance(default_outputs, dict): + default_outputs = {prediction_key.PredictionKey.GENERIC: default_outputs} actual_default_output_alternative_key = DEFAULT_OUTPUT_ALTERNATIVE_KEY output_alternatives = {actual_default_output_alternative_key: (default_problem_type, default_outputs)} diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py index 2ce28f8648e..23d171b58bd 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py @@ -124,6 +124,21 @@ class SavedModelExportUtilsTest(test.TestCase): }) }, output_alternatives) + def test_get_output_alternatives_implicit_single(self): + prediction_tensor = constant_op.constant(["bogus"]) + model_fn_ops = model_fn.ModelFnOps( + model_fn.ModeKeys.INFER, + predictions=prediction_tensor, + output_alternatives=None) + + output_alternatives, _ = saved_model_export_utils.get_output_alternatives( + model_fn_ops) + self.assertEqual({ + "default_output_alternative": (constants.ProblemType.UNSPECIFIED, { + "output": prediction_tensor + }) + }, output_alternatives) + def test_build_all_signature_defs(self): input_features = constant_op.constant(["10"]) input_example = constant_op.constant(["11"])