From 7bf8ccdb4ef5b0b28c1cf0d5084e07ffbf0e2703 Mon Sep 17 00:00:00 2001 From: Yutaka Leon Date: Wed, 24 Jan 2018 15:45:55 -0800 Subject: [PATCH] Set export_outs in KMeans' EstimatorSpec. PiperOrigin-RevId: 183154542 --- .../contrib/factorization/python/ops/kmeans.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py index 9a5413fc3f2..4d0f9b24240 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -25,6 +25,7 @@ import time from tensorflow.contrib.factorization.python.ops import clustering_ops from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.export import export_output from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -32,6 +33,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics from tensorflow.python.ops import state_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary import summary from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util @@ -207,6 +209,15 @@ class _ModelFn(object): training_hooks.append( _LossRelativeChangeHook(loss, self._relative_tolerance)) + export_outputs = { + KMeansClustering.ALL_DISTANCES: + export_output.PredictOutput(all_distances[0]), + KMeansClustering.CLUSTER_INDEX: + export_output.PredictOutput(model_predictions[0]), + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + export_output.PredictOutput(model_predictions[0]) + } + return model_fn_lib.EstimatorSpec( mode=mode, predictions={ @@ -216,7 +227,8 @@ class _ModelFn(object): loss=loss, train_op=training_op, eval_metric_ops={KMeansClustering.SCORE: metrics.mean(loss)}, - training_hooks=training_hooks) + training_hooks=training_hooks, + export_outputs=export_outputs) # TODO(agarwal,ands): support sharded input.