Set export_outs in KMeans' EstimatorSpec.
PiperOrigin-RevId: 183154542
This commit is contained in:
parent
1df1544aeb
commit
7bf8ccdb4e
@ -25,6 +25,7 @@ import time
|
|||||||
from tensorflow.contrib.factorization.python.ops import clustering_ops
|
from tensorflow.contrib.factorization.python.ops import clustering_ops
|
||||||
from tensorflow.python.estimator import estimator
|
from tensorflow.python.estimator import estimator
|
||||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||||
|
from tensorflow.python.estimator.export import export_output
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
@ -32,6 +33,7 @@ from tensorflow.python.ops import math_ops
|
|||||||
from tensorflow.python.ops import metrics
|
from tensorflow.python.ops import metrics
|
||||||
from tensorflow.python.ops import state_ops
|
from tensorflow.python.ops import state_ops
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
from tensorflow.python.saved_model import signature_constants
|
||||||
from tensorflow.python.summary import summary
|
from tensorflow.python.summary import summary
|
||||||
from tensorflow.python.training import session_run_hook
|
from tensorflow.python.training import session_run_hook
|
||||||
from tensorflow.python.training import training_util
|
from tensorflow.python.training import training_util
|
||||||
@ -207,6 +209,15 @@ class _ModelFn(object):
|
|||||||
training_hooks.append(
|
training_hooks.append(
|
||||||
_LossRelativeChangeHook(loss, self._relative_tolerance))
|
_LossRelativeChangeHook(loss, self._relative_tolerance))
|
||||||
|
|
||||||
|
export_outputs = {
|
||||||
|
KMeansClustering.ALL_DISTANCES:
|
||||||
|
export_output.PredictOutput(all_distances[0]),
|
||||||
|
KMeansClustering.CLUSTER_INDEX:
|
||||||
|
export_output.PredictOutput(model_predictions[0]),
|
||||||
|
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
|
||||||
|
export_output.PredictOutput(model_predictions[0])
|
||||||
|
}
|
||||||
|
|
||||||
return model_fn_lib.EstimatorSpec(
|
return model_fn_lib.EstimatorSpec(
|
||||||
mode=mode,
|
mode=mode,
|
||||||
predictions={
|
predictions={
|
||||||
@ -216,7 +227,8 @@ class _ModelFn(object):
|
|||||||
loss=loss,
|
loss=loss,
|
||||||
train_op=training_op,
|
train_op=training_op,
|
||||||
eval_metric_ops={KMeansClustering.SCORE: metrics.mean(loss)},
|
eval_metric_ops={KMeansClustering.SCORE: metrics.mean(loss)},
|
||||||
training_hooks=training_hooks)
|
training_hooks=training_hooks,
|
||||||
|
export_outputs=export_outputs)
|
||||||
|
|
||||||
|
|
||||||
# TODO(agarwal,ands): support sharded input.
|
# TODO(agarwal,ands): support sharded input.
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user