Set export_outs in KMeans' EstimatorSpec.

PiperOrigin-RevId: 183154542
This commit is contained in:
Yutaka Leon 2018-01-24 15:45:55 -08:00 committed by TensorFlower Gardener
parent 1df1544aeb
commit 7bf8ccdb4e

View File

@ -25,6 +25,7 @@ import time
from tensorflow.contrib.factorization.python.ops import clustering_ops
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@ -32,6 +33,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
@ -207,6 +209,15 @@ class _ModelFn(object):
training_hooks.append(
_LossRelativeChangeHook(loss, self._relative_tolerance))
export_outputs = {
KMeansClustering.ALL_DISTANCES:
export_output.PredictOutput(all_distances[0]),
KMeansClustering.CLUSTER_INDEX:
export_output.PredictOutput(model_predictions[0]),
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
export_output.PredictOutput(model_predictions[0])
}
return model_fn_lib.EstimatorSpec(
mode=mode,
predictions={
@ -216,7 +227,8 @@ class _ModelFn(object):
loss=loss,
train_op=training_op,
eval_metric_ops={KMeansClustering.SCORE: metrics.mean(loss)},
training_hooks=training_hooks)
training_hooks=training_hooks,
export_outputs=export_outputs)
# TODO(agarwal,ands): support sharded input.